/**************************************************************************************************
*
* This file is part of BLASFEO.
*
* BLASFEO -- BLAS For Embedded Optimization.
* Copyright (C) 2016-2017 by Gianluca Frison.
* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl.
* All rights reserved.
*
* HPMPC is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* HPMPC is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with HPMPC; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* Author: Gianluca Frison, giaf (at) dtu.dk
*         gianluca.frison (at) imtek.uni-freiburg.de
*
**************************************************************************************************/

#include <mmintrin.h>
#include <xmmintrin.h>  // SSE
#include <emmintrin.h>  // SSE2
#include <pmmintrin.h>  // SSE3
#include <smmintrin.h>  // SSE4
#include <immintrin.h>  // AVX



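// The kernels in this file compute a few columns of D = alpha * A * diag(B)
// (plus beta * C, except in the _a0 variant, which is the beta==0.0 case):
// column j of the output is column j of A scaled by alpha*B[j].
// Judging from the 8-wide loads and the "+= 8*sda" pointer updates, A, C and D
// are assumed to be stored panel-major with a panel height of 8: within a
// panel, column j occupies 8 consecutive floats at offset j*8, and consecutive
// panels are 8*sd floats apart.
//
// The scalar loop below is only an illustration of that semantics under the
// assumed layout (the name ref_sgemm_diag_right_4 is made up); it is kept
// disabled and is not part of the library.
#if 0
static void ref_sgemm_diag_right_4(int kmax, float alpha, float *A, int sda, float *B, float beta, float *C, int sdc, float *D, int sdd)
	{
	int i, j, p, r;
	for(i=0; i<kmax; i++)
		{
		p = i/8; // row panel
		r = i%8; // row within the panel
		for(j=0; j<4; j++)
			{
			// element (i,j) of a panel-major matrix with panel height 8
			D[p*8*sdd + j*8 + r] = alpha*A[p*8*sda + j*8 + r]*B[j] + beta*C[p*8*sdc + j*8 + r];
			}
		}
	}
#endif
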
// B is the diagonal of a matrix, beta==0.0 case
void kernel_sgemm_diag_right_4_a0_lib4(int kmax, float *alpha, float *A, int sda, float *B, float *D, int sdd)
	{

	if(kmax<=0)
		return;

	const int bs = 8;

	int k;

	__m256
		alpha0,
		mask_f,
		sign,
		a_00,
		b_00, b_11, b_22, b_33,
		d_00, d_01, d_02, d_03;

	__m256i
		mask_i;

	alpha0 = _mm256_broadcast_ss( alpha );

	b_00 = _mm256_broadcast_ss( &B[0] );
	b_00 = _mm256_mul_ps( b_00, alpha0 );
	b_11 = _mm256_broadcast_ss( &B[1] );
	b_11 = _mm256_mul_ps( b_11, alpha0 );
	b_22 = _mm256_broadcast_ss( &B[2] );
	b_22 = _mm256_mul_ps( b_22, alpha0 );
	b_33 = _mm256_broadcast_ss( &B[3] );
	b_33 = _mm256_mul_ps( b_33, alpha0 );

	for(k=0; k<kmax-7; k+=8)
		{

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );
		a_00 = _mm256_load_ps( &A[8] );
		d_01 = _mm256_mul_ps( a_00, b_11 );
		a_00 = _mm256_load_ps( &A[16] );
		d_02 = _mm256_mul_ps( a_00, b_22 );
		a_00 = _mm256_load_ps( &A[24] );
		d_03 = _mm256_mul_ps( a_00, b_33 );

		_mm256_store_ps( &D[0], d_00 );
		_mm256_store_ps( &D[8], d_01 );
		_mm256_store_ps( &D[16], d_02 );
		_mm256_store_ps( &D[24], d_03 );

		A += 8*sda;
		D += 8*sdd;

		}
	if(k<kmax)
		{

		const float mask_f[] = {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5};
		float m_f = kmax-k;

		mask_i = _mm256_castps_si256( _mm256_sub_ps( _mm256_loadu_ps( mask_f ), _mm256_broadcast_ss( &m_f ) ) );

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );
		a_00 = _mm256_load_ps( &A[8] );
		d_01 = _mm256_mul_ps( a_00, b_11 );
		a_00 = _mm256_load_ps( &A[16] );
		d_02 = _mm256_mul_ps( a_00, b_22 );
		a_00 = _mm256_load_ps( &A[24] );
		d_03 = _mm256_mul_ps( a_00, b_33 );

		_mm256_maskstore_ps( &D[0], mask_i, d_00 );
		_mm256_maskstore_ps( &D[8], mask_i, d_01 );
		_mm256_maskstore_ps( &D[16], mask_i, d_02 );
		_mm256_maskstore_ps( &D[24], mask_i, d_03 );

		}

	}
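
// Note on the clean-up branch above (and the matching branches in the kernels
// below): when 1..7 rows remain, m_f = kmax-k and the vector
// {0.5, 1.5, ..., 7.5} - m_f is negative exactly in the lanes that still hold
// valid rows; those sign bits form the per-lane mask consumed by
// _mm256_maskstore_ps, so only the remaining rows are written to D. The loads
// from A (and C) are still full 8-float aligned loads, which presumably relies
// on every panel being padded to 8 rows.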



// B is the diagonal of a matrix
void kernel_sgemm_diag_right_4_lib4(int kmax, float *alpha, float *A, int sda, float *B, float *beta, float *C, int sdc, float *D, int sdd)
	{

	if(kmax<=0)
		return;

	const int bs = 8;

	int k;

	__m256
		alpha0, beta0,
		mask_f,
		sign,
		a_00,
		b_00, b_11, b_22, b_33,
		c_00,
		d_00, d_01, d_02, d_03;

	__m256i
		mask_i;

	alpha0 = _mm256_broadcast_ss( alpha );
	beta0 = _mm256_broadcast_ss( beta );

	b_00 = _mm256_broadcast_ss( &B[0] );
	b_00 = _mm256_mul_ps( b_00, alpha0 );
	b_11 = _mm256_broadcast_ss( &B[1] );
	b_11 = _mm256_mul_ps( b_11, alpha0 );
	b_22 = _mm256_broadcast_ss( &B[2] );
	b_22 = _mm256_mul_ps( b_22, alpha0 );
	b_33 = _mm256_broadcast_ss( &B[3] );
	b_33 = _mm256_mul_ps( b_33, alpha0 );

	for(k=0; k<kmax-7; k+=8)
		{

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );
		a_00 = _mm256_load_ps( &A[8] );
		d_01 = _mm256_mul_ps( a_00, b_11 );
		a_00 = _mm256_load_ps( &A[16] );
		d_02 = _mm256_mul_ps( a_00, b_22 );
		a_00 = _mm256_load_ps( &A[24] );
		d_03 = _mm256_mul_ps( a_00, b_33 );

		c_00 = _mm256_load_ps( &C[0] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_00 = _mm256_add_ps( c_00, d_00 );
		c_00 = _mm256_load_ps( &C[8] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_01 = _mm256_add_ps( c_00, d_01 );
		c_00 = _mm256_load_ps( &C[16] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_02 = _mm256_add_ps( c_00, d_02 );
		c_00 = _mm256_load_ps( &C[24] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_03 = _mm256_add_ps( c_00, d_03 );

		_mm256_store_ps( &D[0], d_00 );
		_mm256_store_ps( &D[8], d_01 );
		_mm256_store_ps( &D[16], d_02 );
		_mm256_store_ps( &D[24], d_03 );

		A += 8*sda;
		C += 8*sdc;
		D += 8*sdd;

		}
	if(k<kmax)
		{

		const float mask_f[] = {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5};
		float m_f = kmax-k;

		mask_i = _mm256_castps_si256( _mm256_sub_ps( _mm256_loadu_ps( mask_f ), _mm256_broadcast_ss( &m_f ) ) );

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );
		a_00 = _mm256_load_ps( &A[8] );
		d_01 = _mm256_mul_ps( a_00, b_11 );
		a_00 = _mm256_load_ps( &A[16] );
		d_02 = _mm256_mul_ps( a_00, b_22 );
		a_00 = _mm256_load_ps( &A[24] );
		d_03 = _mm256_mul_ps( a_00, b_33 );

		c_00 = _mm256_load_ps( &C[0] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_00 = _mm256_add_ps( c_00, d_00 );
		c_00 = _mm256_load_ps( &C[8] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_01 = _mm256_add_ps( c_00, d_01 );
		c_00 = _mm256_load_ps( &C[16] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_02 = _mm256_add_ps( c_00, d_02 );
		c_00 = _mm256_load_ps( &C[24] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_03 = _mm256_add_ps( c_00, d_03 );

		_mm256_maskstore_ps( &D[0], mask_i, d_00 );
		_mm256_maskstore_ps( &D[8], mask_i, d_01 );
		_mm256_maskstore_ps( &D[16], mask_i, d_02 );
		_mm256_maskstore_ps( &D[24], mask_i, d_03 );

		}

	}



// B is the diagonal of a matrix
void kernel_sgemm_diag_right_3_lib4(int kmax, float *alpha, float *A, int sda, float *B, float *beta, float *C, int sdc, float *D, int sdd)
	{

	if(kmax<=0)
		return;

	const int bs = 8;

	int k;

	__m256
		alpha0, beta0,
		mask_f,
		sign,
		a_00,
		b_00, b_11, b_22,
		c_00,
		d_00, d_01, d_02;

	__m256i
		mask_i;

	alpha0 = _mm256_broadcast_ss( alpha );
	beta0 = _mm256_broadcast_ss( beta );

	b_00 = _mm256_broadcast_ss( &B[0] );
	b_00 = _mm256_mul_ps( b_00, alpha0 );
	b_11 = _mm256_broadcast_ss( &B[1] );
	b_11 = _mm256_mul_ps( b_11, alpha0 );
	b_22 = _mm256_broadcast_ss( &B[2] );
	b_22 = _mm256_mul_ps( b_22, alpha0 );

	for(k=0; k<kmax-7; k+=8)
		{

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );
		a_00 = _mm256_load_ps( &A[8] );
		d_01 = _mm256_mul_ps( a_00, b_11 );
		a_00 = _mm256_load_ps( &A[16] );
		d_02 = _mm256_mul_ps( a_00, b_22 );

		c_00 = _mm256_load_ps( &C[0] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_00 = _mm256_add_ps( c_00, d_00 );
		c_00 = _mm256_load_ps( &C[8] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_01 = _mm256_add_ps( c_00, d_01 );
		c_00 = _mm256_load_ps( &C[16] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_02 = _mm256_add_ps( c_00, d_02 );

		_mm256_store_ps( &D[0], d_00 );
		_mm256_store_ps( &D[8], d_01 );
		_mm256_store_ps( &D[16], d_02 );

		A += 8*sda;
		C += 8*sdc;
		D += 8*sdd;

		}
	if(k<kmax)
		{

		const float mask_f[] = {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5};
		float m_f = kmax-k;

		mask_i = _mm256_castps_si256( _mm256_sub_ps( _mm256_loadu_ps( mask_f ), _mm256_broadcast_ss( &m_f ) ) );

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );
		a_00 = _mm256_load_ps( &A[8] );
		d_01 = _mm256_mul_ps( a_00, b_11 );
		a_00 = _mm256_load_ps( &A[16] );
		d_02 = _mm256_mul_ps( a_00, b_22 );

		c_00 = _mm256_load_ps( &C[0] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_00 = _mm256_add_ps( c_00, d_00 );
		c_00 = _mm256_load_ps( &C[8] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_01 = _mm256_add_ps( c_00, d_01 );
		c_00 = _mm256_load_ps( &C[16] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_02 = _mm256_add_ps( c_00, d_02 );

		_mm256_maskstore_ps( &D[0], mask_i, d_00 );
		_mm256_maskstore_ps( &D[8], mask_i, d_01 );
		_mm256_maskstore_ps( &D[16], mask_i, d_02 );

		}

	}



// B is the diagonal of a matrix
void kernel_sgemm_diag_right_2_lib4(int kmax, float *alpha, float *A, int sda, float *B, float *beta, float *C, int sdc, float *D, int sdd)
	{

	if(kmax<=0)
		return;

	const int bs = 8;

	int k;

	__m256
		alpha0, beta0,
		mask_f,
		sign,
		a_00,
		b_00, b_11,
		c_00,
		d_00, d_01;

	__m256i
		mask_i;

	alpha0 = _mm256_broadcast_ss( alpha );
	beta0 = _mm256_broadcast_ss( beta );

	b_00 = _mm256_broadcast_ss( &B[0] );
	b_00 = _mm256_mul_ps( b_00, alpha0 );
	b_11 = _mm256_broadcast_ss( &B[1] );
	b_11 = _mm256_mul_ps( b_11, alpha0 );

	for(k=0; k<kmax-7; k+=8)
		{

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );
		a_00 = _mm256_load_ps( &A[8] );
		d_01 = _mm256_mul_ps( a_00, b_11 );

		c_00 = _mm256_load_ps( &C[0] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_00 = _mm256_add_ps( c_00, d_00 );
		c_00 = _mm256_load_ps( &C[8] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_01 = _mm256_add_ps( c_00, d_01 );

		_mm256_store_ps( &D[0], d_00 );
		_mm256_store_ps( &D[8], d_01 );

		A += 8*sda;
		C += 8*sdc;
		D += 8*sdd;

		}
	if(k<kmax)
		{

		const float mask_f[] = {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5};
		float m_f = kmax-k;

		mask_i = _mm256_castps_si256( _mm256_sub_ps( _mm256_loadu_ps( mask_f ), _mm256_broadcast_ss( &m_f ) ) );

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );
		a_00 = _mm256_load_ps( &A[8] );
		d_01 = _mm256_mul_ps( a_00, b_11 );

		c_00 = _mm256_load_ps( &C[0] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_00 = _mm256_add_ps( c_00, d_00 );
		c_00 = _mm256_load_ps( &C[8] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_01 = _mm256_add_ps( c_00, d_01 );

		_mm256_maskstore_ps( &D[0], mask_i, d_00 );
		_mm256_maskstore_ps( &D[8], mask_i, d_01 );

		}

	}



// B is the diagonal of a matrix
void kernel_sgemm_diag_right_1_lib4(int kmax, float *alpha, float *A, int sda, float *B, float *beta, float *C, int sdc, float *D, int sdd)
	{

	if(kmax<=0)
		return;

	const int bs = 8;

	int k;

	__m256
		alpha0, beta0,
		mask_f,
		sign,
		a_00,
		b_00,
		c_00,
		d_00;

	__m256i
		mask_i;

	alpha0 = _mm256_broadcast_ss( alpha );
	beta0 = _mm256_broadcast_ss( beta );

	b_00 = _mm256_broadcast_ss( &B[0] );
	b_00 = _mm256_mul_ps( b_00, alpha0 );

	for(k=0; k<kmax-7; k+=8)
		{

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );

		c_00 = _mm256_load_ps( &C[0] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_00 = _mm256_add_ps( c_00, d_00 );

		_mm256_store_ps( &D[0], d_00 );

		A += 8*sda;
		C += 8*sdc;
		D += 8*sdd;

		}
	if(k<kmax)
		{

		const float mask_f[] = {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5};
		float m_f = kmax-k;

		mask_i = _mm256_castps_si256( _mm256_sub_ps( _mm256_loadu_ps( mask_f ), _mm256_broadcast_ss( &m_f ) ) );

		a_00 = _mm256_load_ps( &A[0] );
		d_00 = _mm256_mul_ps( a_00, b_00 );

		c_00 = _mm256_load_ps( &C[0] );
		c_00 = _mm256_mul_ps( c_00, beta0 );
		d_00 = _mm256_add_ps( c_00, d_00 );

		_mm256_maskstore_ps( &D[0], mask_i, d_00 );

		}

	}



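// Usage sketch (illustration only, not part of the library): a caller forming
// D = alpha*A*diag(B) + beta*C for an m x n panel-major matrix could tile the
// columns in blocks of 4 and fall back to the _3/_2/_1 kernels for the
// remainder. The driver name below is made up; column j is assumed to start at
// offset j*8 within the first panel, as in the layout described at the top of
// the file.
#if 0
void sgemm_diag_right_example(int m, int n, float alpha, float *A, int sda, float *B, float beta, float *C, int sdc, float *D, int sdd)
	{
	int j;
	for(j=0; j<n-3; j+=4)
		kernel_sgemm_diag_right_4_lib4(m, &alpha, &A[j*8], sda, &B[j], &beta, &C[j*8], sdc, &D[j*8], sdd);
	if(n-j==3)
		kernel_sgemm_diag_right_3_lib4(m, &alpha, &A[j*8], sda, &B[j], &beta, &C[j*8], sdc, &D[j*8], sdd);
	else if(n-j==2)
		kernel_sgemm_diag_right_2_lib4(m, &alpha, &A[j*8], sda, &B[j], &beta, &C[j*8], sdc, &D[j*8], sdd);
	else if(n-j==1)
		kernel_sgemm_diag_right_1_lib4(m, &alpha, &A[j*8], sda, &B[j], &beta, &C[j*8], sdc, &D[j*8], sdd);
	}
#endif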