blob: d64f9770954d472d69d9aef4f79b89c0afa87fa8 [file] [log] [blame]
Austin Schuh9a24b372018-01-28 16:12:29 -08001/**************************************************************************************************
2* *
3* This file is part of BLASFEO. *
4* *
5* BLASFEO -- BLAS For Embedded Optimization. *
6* Copyright (C) 2016-2017 by Gianluca Frison. *
7* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
8* All rights reserved. *
9* *
10* HPMPC is free software; you can redistribute it and/or *
11* modify it under the terms of the GNU Lesser General Public *
12* License as published by the Free Software Foundation; either *
13* version 2.1 of the License, or (at your option) any later version. *
14* *
15* HPMPC is distributed in the hope that it will be useful, *
16* but WITHOUT ANY WARRANTY; without even the implied warranty of *
17* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
18* See the GNU Lesser General Public License for more details. *
19* *
20* You should have received a copy of the GNU Lesser General Public *
21* License along with HPMPC; if not, write to the Free Software *
22* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA *
23* *
24* Author: Gianluca Frison, giaf (at) dtu.dk *
25* gianluca.frison (at) imtek.uni-freiburg.de *
26* *
27**************************************************************************************************/
28
29#include <mmintrin.h>
30#include <xmmintrin.h> // SSE
31#include <emmintrin.h> // SSE2
32#include <pmmintrin.h> // SSE3
33#include <smmintrin.h> // SSE4
34#include <immintrin.h> // AVX
35
36
37
// Right multiplication by a diagonal matrix (B holds its diagonal), beta==0.0 case.
//
// For a kmax x 4 block stored in 4-row panels (panel-major "lib4" layout:
// element (i,j) of A at A[(i/4)*4*sda + i%4 + 4*j], likewise D with sdd)
// computes
//   D[:,j] = alpha * B[j] * A[:,j],   j = 0..3,
// i.e. D = alpha * A * diag(B[0..3]); no C operand is read.
// A last partial panel (kmax not a multiple of 4) is stored through a lane
// mask, so rows >= kmax of D are left untouched; all 4 rows of that panel of
// A are still read. A and D go through aligned AVX loads/stores and must be
// 32-byte aligned.
void kernel_dgemm_diag_right_4_a0_lib4(int kmax, double *alpha, double *A, int sda, double *B, double *D, int sdd)
	{

	if(kmax<=0)
		return;

	const __m256d alpha_v = _mm256_broadcast_sd( alpha );

	// alpha is folded into the diagonal once, outside the row loop
	__m256d diag[4];
	for(int jj=0; jj<4; jj++)
		diag[jj] = _mm256_mul_pd( _mm256_broadcast_sd( &B[jj] ), alpha_v );

	for(int row=0; row<kmax; row+=4)
		{
		__m256d res[4];
		for(int jj=0; jj<4; jj++)
			res[jj] = _mm256_mul_pd( _mm256_load_pd( &A[4*jj] ), diag[jj] );

		const int rem = kmax-row;
		if(rem<4)
			{
			// partial panel (1..3 rows): maskstore writes a lane only when its
			// sign bit is set, i.e. lanes 0..rem-1
			const __m256i keep = _mm256_set_epi64x( -(long long)(rem>3), -(long long)(rem>2), -(long long)(rem>1), -(long long)(rem>0) );
			for(int jj=0; jj<4; jj++)
				_mm256_maskstore_pd( &D[4*jj], keep, res[jj] );
			break;
			}

		for(int jj=0; jj<4; jj++)
			_mm256_store_pd( &D[4*jj], res[jj] );

		A += 4*sda;
		D += 4*sdd;
		}

	}
117
118
119
// Right multiplication by a diagonal matrix (B holds its diagonal).
//
// For kmax x 4 blocks stored in 4-row panels (panel-major "lib4" layout:
// element (i,j) of A at A[(i/4)*4*sda + i%4 + 4*j], likewise C with sdc and
// D with sdd) computes
//   D[:,j] = alpha * B[j] * A[:,j] + beta * C[:,j],   j = 0..3,
// i.e. D = alpha * A * diag(B[0..3]) + beta * C.
// A last partial panel (kmax not a multiple of 4) is stored through a lane
// mask, so rows >= kmax of D are left untouched; all 4 rows of that panel of
// A and C are still read. A, C and D go through aligned AVX loads/stores and
// must be 32-byte aligned.
void kernel_dgemm_diag_right_4_lib4(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd)
	{

	if(kmax<=0)
		return;

	const __m256d alpha_v = _mm256_broadcast_sd( alpha );
	const __m256d beta_v = _mm256_broadcast_sd( beta );

	// alpha is folded into the diagonal once, outside the row loop
	__m256d diag[4];
	for(int jj=0; jj<4; jj++)
		diag[jj] = _mm256_mul_pd( _mm256_broadcast_sd( &B[jj] ), alpha_v );

	for(int row=0; row<kmax; row+=4)
		{
		__m256d res[4];
		for(int jj=0; jj<4; jj++)
			res[jj] = _mm256_mul_pd( _mm256_load_pd( &A[4*jj] ), diag[jj] );
		for(int jj=0; jj<4; jj++)
			res[jj] = _mm256_add_pd( _mm256_mul_pd( _mm256_load_pd( &C[4*jj] ), beta_v ), res[jj] );

		const int rem = kmax-row;
		if(rem<4)
			{
			// partial panel (1..3 rows): maskstore writes a lane only when its
			// sign bit is set, i.e. lanes 0..rem-1
			const __m256i keep = _mm256_set_epi64x( -(long long)(rem>3), -(long long)(rem>2), -(long long)(rem>1), -(long long)(rem>0) );
			for(int jj=0; jj<4; jj++)
				_mm256_maskstore_pd( &D[4*jj], keep, res[jj] );
			break;
			}

		for(int jj=0; jj<4; jj++)
			_mm256_store_pd( &D[4*jj], res[jj] );

		A += 4*sda;
		C += 4*sdc;
		D += 4*sdd;
		}

	}
228
229
230
// Right multiplication by a diagonal matrix, 3-column variant (B holds the
// diagonal entries B[0..2]).
//
// For kmax x 3 blocks stored in 4-row panels (panel-major "lib4" layout:
// element (i,j) of A at A[(i/4)*4*sda + i%4 + 4*j], likewise C with sdc and
// D with sdd) computes
//   D[:,j] = alpha * B[j] * A[:,j] + beta * C[:,j],   j = 0..2.
// Column 3 of D is not touched. A last partial panel (kmax not a multiple
// of 4) is stored through a lane mask, so rows >= kmax of D are left
// untouched; all 4 rows of that panel of A and C are still read. A, C and D
// go through aligned AVX loads/stores and must be 32-byte aligned.
void kernel_dgemm_diag_right_3_lib4(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd)
	{

	if(kmax<=0)
		return;

	const __m256d alpha_v = _mm256_broadcast_sd( alpha );
	const __m256d beta_v = _mm256_broadcast_sd( beta );

	// alpha is folded into the diagonal once, outside the row loop
	__m256d diag[3];
	for(int jj=0; jj<3; jj++)
		diag[jj] = _mm256_mul_pd( _mm256_broadcast_sd( &B[jj] ), alpha_v );

	for(int row=0; row<kmax; row+=4)
		{
		__m256d res[3];
		for(int jj=0; jj<3; jj++)
			res[jj] = _mm256_mul_pd( _mm256_load_pd( &A[4*jj] ), diag[jj] );
		for(int jj=0; jj<3; jj++)
			res[jj] = _mm256_add_pd( _mm256_mul_pd( _mm256_load_pd( &C[4*jj] ), beta_v ), res[jj] );

		const int rem = kmax-row;
		if(rem<4)
			{
			// partial panel (1..3 rows): maskstore writes a lane only when its
			// sign bit is set, i.e. lanes 0..rem-1
			const __m256i keep = _mm256_set_epi64x( -(long long)(rem>3), -(long long)(rem>2), -(long long)(rem>1), -(long long)(rem>0) );
			for(int jj=0; jj<3; jj++)
				_mm256_maskstore_pd( &D[4*jj], keep, res[jj] );
			break;
			}

		for(int jj=0; jj<3; jj++)
			_mm256_store_pd( &D[4*jj], res[jj] );

		A += 4*sda;
		C += 4*sdc;
		D += 4*sdd;
		}

	}
325
326
327
// Right multiplication by a diagonal matrix, 2-column variant (B holds the
// diagonal entries B[0..1]).
//
// For kmax x 2 blocks stored in 4-row panels (panel-major "lib4" layout:
// element (i,j) of A at A[(i/4)*4*sda + i%4 + 4*j], likewise C with sdc and
// D with sdd) computes
//   D[:,j] = alpha * B[j] * A[:,j] + beta * C[:,j],   j = 0..1.
// Columns 2..3 of D are not touched. A last partial panel (kmax not a
// multiple of 4) is stored through a lane mask, so rows >= kmax of D are left
// untouched; all 4 rows of that panel of A and C are still read. A, C and D
// go through aligned AVX loads/stores and must be 32-byte aligned.
void kernel_dgemm_diag_right_2_lib4(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd)
	{

	if(kmax<=0)
		return;

	const __m256d alpha_v = _mm256_broadcast_sd( alpha );
	const __m256d beta_v = _mm256_broadcast_sd( beta );

	// alpha is folded into the diagonal once, outside the row loop
	__m256d diag[2];
	for(int jj=0; jj<2; jj++)
		diag[jj] = _mm256_mul_pd( _mm256_broadcast_sd( &B[jj] ), alpha_v );

	for(int row=0; row<kmax; row+=4)
		{
		__m256d res[2];
		for(int jj=0; jj<2; jj++)
			res[jj] = _mm256_mul_pd( _mm256_load_pd( &A[4*jj] ), diag[jj] );
		for(int jj=0; jj<2; jj++)
			res[jj] = _mm256_add_pd( _mm256_mul_pd( _mm256_load_pd( &C[4*jj] ), beta_v ), res[jj] );

		const int rem = kmax-row;
		if(rem<4)
			{
			// partial panel (1..3 rows): maskstore writes a lane only when its
			// sign bit is set, i.e. lanes 0..rem-1
			const __m256i keep = _mm256_set_epi64x( -(long long)(rem>3), -(long long)(rem>2), -(long long)(rem>1), -(long long)(rem>0) );
			for(int jj=0; jj<2; jj++)
				_mm256_maskstore_pd( &D[4*jj], keep, res[jj] );
			break;
			}

		for(int jj=0; jj<2; jj++)
			_mm256_store_pd( &D[4*jj], res[jj] );

		A += 4*sda;
		C += 4*sdc;
		D += 4*sdd;
		}

	}
408
409
410
// Right multiplication by a diagonal matrix, single-column variant (only the
// diagonal entry B[0] is used).
//
// For a kmax x 1 block stored in 4-row panels (panel-major "lib4" layout:
// row i of the column of A at A[(i/4)*4*sda + i%4], likewise C with sdc and
// D with sdd) computes
//   D[:,0] = alpha * B[0] * A[:,0] + beta * C[:,0].
// Columns 1..3 of D are not touched. A last partial panel (kmax not a
// multiple of 4) is stored through a lane mask, so rows >= kmax of D are left
// untouched; all 4 rows of that panel of A and C are still read. A, C and D
// go through aligned AVX loads/stores and must be 32-byte aligned.
void kernel_dgemm_diag_right_1_lib4(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd)
	{

	if(kmax<=0)
		return;

	int k;

	__m256d
		alpha0, beta0,
		a_00,
		b_00,
		c_00,
		d_00;

	__m256i
		mask_i;

	alpha0 = _mm256_broadcast_sd( alpha );
	beta0 = _mm256_broadcast_sd( beta );

	// Fold alpha into the diagonal entry, as the 2/3/4-column variants do.
	// Without this multiplication alpha would be silently ignored.
	b_00 = _mm256_broadcast_sd( &B[0] );
	b_00 = _mm256_mul_pd( b_00, alpha0 );

	// full 4-row panels
	for(k=0; k<kmax-3; k+=4)
		{

		a_00 = _mm256_load_pd( &A[0] );
		d_00 = _mm256_mul_pd( a_00, b_00 );

		c_00 = _mm256_load_pd( &C[0] );
		c_00 = _mm256_mul_pd( c_00, beta0 );
		d_00 = _mm256_add_pd( c_00, d_00 );

		_mm256_store_pd( &D[0], d_00 );

		A += 4*sda;
		C += 4*sdc;
		D += 4*sdd;

		}
	// trailing partial panel: 1..3 rows remain
	if(k<kmax)
		{

		// lane i holds (i + 0.5 - rows_left), which is negative (sign bit set)
		// exactly when i < rows_left; maskstore only writes lanes whose sign
		// bit is set
		const double lane_pos[] = {0.5, 1.5, 2.5, 3.5};
		double rows_left = kmax-k;

		mask_i = _mm256_castpd_si256( _mm256_sub_pd( _mm256_loadu_pd( lane_pos ), _mm256_broadcast_sd( &rows_left ) ) );

		a_00 = _mm256_load_pd( &A[0] );
		d_00 = _mm256_mul_pd( a_00, b_00 );

		c_00 = _mm256_load_pd( &C[0] );
		c_00 = _mm256_mul_pd( c_00, beta0 );
		d_00 = _mm256_add_pd( c_00, d_00 );

		_mm256_maskstore_pd( &D[0], mask_i, d_00 );

		}

	}
476
477
478
// Left multiplication by a diagonal matrix (A holds its 4 diagonal entries),
// beta==0.0 case.
//
// B and D are single 4-row panels with kmax columns, column k stored at
// offset 4*k. Computes
//   D[i,k] = alpha * A[i] * B[i,k],   i = 0..3, k = 0..kmax-1,
// i.e. D = alpha * diag(A[0..3]) * B. All accesses use aligned AVX
// loads/stores, so A, B and D must be 32-byte aligned.
void kernel_dgemm_diag_left_4_a0_lib4(int kmax, double *alpha, double *A, double *B, double *D)
	{

	if(kmax<=0)
		return;

	// alpha * diag(A), computed once
	const __m256d diag = _mm256_mul_pd( _mm256_load_pd( &A[0] ), _mm256_broadcast_sd( alpha ) );

	int col = 0;

	// four columns per iteration
	for(; kmax-col>3; col+=4)
		{
		__m256d out[4];
		for(int jj=0; jj<4; jj++)
			out[jj] = _mm256_mul_pd( diag, _mm256_load_pd( &B[4*jj] ) );
		for(int jj=0; jj<4; jj++)
			_mm256_store_pd( &D[4*jj], out[jj] );
		B += 16;
		D += 16;
		}

	// remaining 0..3 columns, one at a time
	for(; col<kmax; col++)
		{
		const __m256d out = _mm256_mul_pd( diag, _mm256_load_pd( &B[0] ) );
		_mm256_store_pd( &D[0], out );
		B += 4;
		D += 4;
		}

	}
537
538
539
// Left multiplication by a diagonal matrix (A holds its 4 diagonal entries).
//
// B, C and D are single 4-row panels with kmax columns, column k stored at
// offset 4*k. Computes
//   D[i,k] = alpha * A[i] * B[i,k] + beta * C[i,k],   i = 0..3, k = 0..kmax-1,
// i.e. D = alpha * diag(A[0..3]) * B + beta * C. All accesses use aligned AVX
// loads/stores, so A, B, C and D must be 32-byte aligned.
void kernel_dgemm_diag_left_4_lib4(int kmax, double *alpha, double *A, double *B, double *beta, double *C, double *D)
	{

	if(kmax<=0)
		return;

	const __m256d beta_v = _mm256_broadcast_sd( beta );
	// alpha * diag(A), computed once
	const __m256d diag = _mm256_mul_pd( _mm256_load_pd( &A[0] ), _mm256_broadcast_sd( alpha ) );

	int col = 0;

	// four columns per iteration
	for(; kmax-col>3; col+=4)
		{
		__m256d out[4];
		for(int jj=0; jj<4; jj++)
			out[jj] = _mm256_mul_pd( diag, _mm256_load_pd( &B[4*jj] ) );
		for(int jj=0; jj<4; jj++)
			out[jj] = _mm256_add_pd( _mm256_mul_pd( _mm256_load_pd( &C[4*jj] ), beta_v ), out[jj] );
		for(int jj=0; jj<4; jj++)
			_mm256_store_pd( &D[4*jj], out[jj] );
		B += 16;
		C += 16;
		D += 16;
		}

	// remaining 0..3 columns, one at a time
	for(; col<kmax; col++)
		{
		__m256d out = _mm256_mul_pd( diag, _mm256_load_pd( &B[0] ) );
		out = _mm256_add_pd( _mm256_mul_pd( _mm256_load_pd( &C[0] ), beta_v ), out );
		_mm256_store_pd( &D[0], out );
		B += 4;
		C += 4;
		D += 4;
		}

	}
619
620
621
// Left multiplication by a diagonal matrix, 3-row variant (A holds the
// diagonal entries).
//
// B, C and D are single 4-row panels with kmax columns, column k stored at
// offset 4*k. Computes, for rows i = 0..2 only,
//   D[i,k] = alpha * A[i] * B[i,k] + beta * C[i,k],   k = 0..kmax-1.
// Row 3 of D is never written (masked store), but all 4 lanes of A, B and C
// are loaded, so A[3] and row 3 of B/C must be readable. Those loads are
// aligned, so A, B and C must be 32-byte aligned.
void kernel_dgemm_diag_left_3_lib4(int kmax, double *alpha, double *A, double *B, double *beta, double *C, double *D)
	{

	if(kmax<=0)
		return;

	// _mm256_set_epi64x takes lanes high-to-low: lanes 0..2 enabled (sign bit
	// set), lane 3 (row 3) disabled
	const __m256i rows_012 = _mm256_set_epi64x( 0, -1, -1, -1 );

	const __m256d beta_v = _mm256_broadcast_sd( beta );
	// alpha * diag(A), computed once
	const __m256d diag = _mm256_mul_pd( _mm256_load_pd( &A[0] ), _mm256_broadcast_sd( alpha ) );

	int col = 0;

	// four columns per iteration
	for(; kmax-col>3; col+=4)
		{
		__m256d out[4];
		for(int jj=0; jj<4; jj++)
			out[jj] = _mm256_mul_pd( diag, _mm256_load_pd( &B[4*jj] ) );
		for(int jj=0; jj<4; jj++)
			out[jj] = _mm256_add_pd( _mm256_mul_pd( _mm256_load_pd( &C[4*jj] ), beta_v ), out[jj] );
		for(int jj=0; jj<4; jj++)
			_mm256_maskstore_pd( &D[4*jj], rows_012, out[jj] );
		B += 16;
		C += 16;
		D += 16;
		}

	// remaining 0..3 columns, one at a time
	for(; col<kmax; col++)
		{
		__m256d out = _mm256_mul_pd( diag, _mm256_load_pd( &B[0] ) );
		out = _mm256_add_pd( _mm256_mul_pd( _mm256_load_pd( &C[0] ), beta_v ), out );
		_mm256_maskstore_pd( &D[0], rows_012, out );
		B += 4;
		C += 4;
		D += 4;
		}

	}
706
707
708
// Left multiplication by a diagonal matrix, 2-row variant (A holds the
// diagonal entries A[0..1]).
//
// B, C and D are single 4-row panels with kmax columns, column k stored at
// offset 4*k. Computes, for rows i = 0..1 only,
//   D[i,k] = alpha * A[i] * B[i,k] + beta * C[i,k],   k = 0..kmax-1.
// Rows 2..3 of D are not touched. Uses aligned 128-bit SSE loads/stores, so
// A, B, C and D must be 16-byte aligned.
void kernel_dgemm_diag_left_2_lib4(int kmax, double *alpha, double *A, double *B, double *beta, double *C, double *D)
	{

	if(kmax<=0)
		return;

	const __m128d beta_v = _mm_loaddup_pd( beta );
	// alpha * diag(A[0..1]), computed once
	const __m128d diag = _mm_mul_pd( _mm_load_pd( &A[0] ), _mm_loaddup_pd( alpha ) );

	int col = 0;

	// four columns per iteration
	for(; kmax-col>3; col+=4)
		{
		__m128d out[4];
		for(int jj=0; jj<4; jj++)
			out[jj] = _mm_mul_pd( diag, _mm_load_pd( &B[4*jj] ) );
		for(int jj=0; jj<4; jj++)
			out[jj] = _mm_add_pd( _mm_mul_pd( _mm_load_pd( &C[4*jj] ), beta_v ), out[jj] );
		for(int jj=0; jj<4; jj++)
			_mm_store_pd( &D[4*jj], out[jj] );
		B += 16;
		C += 16;
		D += 16;
		}

	// remaining 0..3 columns, one at a time
	for(; col<kmax; col++)
		{
		__m128d out = _mm_mul_pd( diag, _mm_load_pd( &B[0] ) );
		out = _mm_add_pd( _mm_mul_pd( _mm_load_pd( &C[0] ), beta_v ), out );
		_mm_store_pd( &D[0], out );
		B += 4;
		C += 4;
		D += 4;
		}

	}
789
790
// Left multiplication by a diagonal matrix, single-row variant (only the
// diagonal entry A[0] is used).
//
// B, C and D are single 4-row panels with kmax columns, column k stored at
// offset 4*k. Computes, for row 0 only,
//   D[0,k] = beta * C[0,k] + alpha * A[0] * B[0,k],   k = 0..kmax-1.
// Rows 1..3 of D are neither read nor written. Plain scalar code, so no
// alignment requirement.
void kernel_dgemm_diag_left_1_lib4(int kmax, double *alpha, double *A, double *B, double *beta, double *C, double *D)
	{

	if(kmax<=0)
		return;

	const double beta_s = beta[0];
	// alpha * A[0], computed once
	const double diag = A[0] * alpha[0];

	int col = 0;

	// four columns per iteration (panel column stride is 4)
	for(; kmax-col>3; col+=4)
		{
		for(int jj=0; jj<4; jj++)
			D[4*jj] = beta_s * C[4*jj] + diag * B[4*jj];
		B += 16;
		C += 16;
		D += 16;
		}

	// remaining 0..3 columns, one at a time
	for(; col<kmax; col++)
		{
		D[0] = beta_s * C[0] + diag * B[0];
		B += 4;
		C += 4;
		D += 4;
		}

	}
864
865
866