/**************************************************************************************************
* *
* This file is part of BLASFEO. *
* *
* BLASFEO -- BLAS For Embedded Optimization. *
* Copyright (C) 2016-2017 by Gianluca Frison. *
* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
* All rights reserved. *
* *
* HPMPC is free software; you can redistribute it and/or *
* modify it under the terms of the GNU Lesser General Public *
* License as published by the Free Software Foundation; either *
* version 2.1 of the License, or (at your option) any later version. *
* *
* HPMPC is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
* See the GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public *
* License along with HPMPC; if not, write to the Free Software *
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA *
* *
* Author: Gianluca Frison, giaf (at) dtu.dk *
* gianluca.frison (at) imtek.uni-freiburg.de *
* *
**************************************************************************************************/

#include <stdlib.h>
#include <stdio.h>

#include "../include/blasfeo_common.h"
#include "../include/blasfeo_s_kernel.h"
#include "../include/blasfeo_s_aux.h"



/****************************
* old interface
****************************/

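/*
 * Note on the storage format used throughout this file: all matrices are in
 * the packed panel-major ("lib4") format, i.e. stored by row panels of height
 * bs=4, with sda/sdb/sdc/sdd the panel strides (the padded number of columns,
 * sA->cn in the strmat interface below). As a sketch, element (i,j) of a
 * panel-major matrix pA sits at
 *
 *     pA[(i/bs)*bs*sda + j*bs + i%bs]
 *
 * which is why, with i a multiple of bs, the code below addresses the panel
 * holding rows i..i+3 as &pA[i*sda] and column j inside that panel as
 * &pA[j*bs+i*sda].
 */
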
void sgemm_nt_lib(int m, int n, int k, float alpha, float *pA, int sda, float *pB, int sdb, float beta, float *pC, int sdc, float *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	const int bs = 4;

	int i, j, l;

	i = 0;

#if defined(TARGET_ARMV8A_ARM_CORTEX_A57)
	for(; i<m-15; i+=16)
		{
		j = 0;
		for(; j<n-3; j+=4)
			{
			kernel_sgemm_nt_16x4_lib4(k, &alpha, &pA[i*sda], sda, &pB[j*sdb], &beta, &pC[j*bs+i*sdc], sdc, &pD[j*bs+i*sdd], sdd);
			}
		if(j<n)
			{
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+0)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+0)*sdc], &pD[j*bs+(i+0)*sdd], m-(i+0), n-j);
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+4)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+4)*sdc], &pD[j*bs+(i+4)*sdd], m-(i+4), n-j);
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+8)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+8)*sdc], &pD[j*bs+(i+8)*sdd], m-(i+8), n-j);
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+12)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+12)*sdc], &pD[j*bs+(i+12)*sdd], m-(i+12), n-j);
			}
		}
#endif
#if defined(TARGET_ARMV7A_ARM_CORTEX_A15) | defined(TARGET_ARMV8A_ARM_CORTEX_A57)
	for(; i<m-11; i+=12)
		{
		j = 0;
		for(; j<n-3; j+=4)
			{
			kernel_sgemm_nt_12x4_lib4(k, &alpha, &pA[i*sda], sda, &pB[j*sdb], &beta, &pC[j*bs+i*sdc], sdc, &pD[j*bs+i*sdd], sdd);
			}
		if(j<n)
			{
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+0)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+0)*sdc], &pD[j*bs+(i+0)*sdd], m-(i+0), n-j);
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+4)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+4)*sdc], &pD[j*bs+(i+4)*sdd], m-(i+4), n-j);
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+8)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+8)*sdc], &pD[j*bs+(i+8)*sdd], m-(i+8), n-j);
			}
		}
#endif
#if defined(TARGET_ARMV8A_ARM_CORTEX_A57) | defined(TARGET_ARMV7A_ARM_CORTEX_A15)
	for(; i<m-7; i+=8)
		{
		j = 0;
#if defined(TARGET_ARMV8A_ARM_CORTEX_A57)
		for(; j<n-7; j+=8)
			{
			kernel_sgemm_nt_8x8_lib4(k, &alpha, &pA[i*sda], sda, &pB[j*sdb], sdb, &beta, &pC[j*bs+i*sdc], sdc, &pD[j*bs+i*sdd], sdd);
			}
#endif
		for(; j<n-3; j+=4)
			{
			kernel_sgemm_nt_8x4_lib4(k, &alpha, &pA[i*sda], sda, &pB[j*sdb], &beta, &pC[j*bs+i*sdc], sdc, &pD[j*bs+i*sdd], sdd);
			}
		if(j<n)
			{
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+0)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+0)*sdc], &pD[j*bs+(i+0)*sdd], m-(i+0), n-j);
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+4)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+4)*sdc], &pD[j*bs+(i+4)*sdd], m-(i+4), n-j);
			}
		}
#endif
	for(; i<m-3; i+=4)
		{
		j = 0;
		for(; j<n-3; j+=4)
			{
			kernel_sgemm_nt_4x4_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd]);
			}
		if(j<n)
			{
			kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
			}
		}
	if(m>i)
		{
		goto left_4;
		}

	// common return if i==m
	return;

	// clean up loops definitions

	left_12:
	j = 0;
	for(; j<n; j+=4)
		{
		kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+0)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+0)*sdc], &pD[j*bs+(i+0)*sdd], m-(i+0), n-j);
		kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+4)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+4)*sdc], &pD[j*bs+(i+4)*sdd], m-(i+4), n-j);
		kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+8)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+8)*sdc], &pD[j*bs+(i+8)*sdd], m-(i+8), n-j);
		}
	return;

	left_8:
	j = 0;
	for(; j<n; j+=4)
		{
		kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+0)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+0)*sdc], &pD[j*bs+(i+0)*sdd], m-(i+0), n-j);
		kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[(i+4)*sda], &pB[j*sdb], &beta, &pC[j*bs+(i+4)*sdc], &pD[j*bs+(i+4)*sdd], m-(i+4), n-j);
		}
	return;

	left_4:
	j = 0;
	for(; j<n; j+=4)
		{
		kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
		}
	return;

	}



void sgemm_nn_lib(int m, int n, int k, float alpha, float *pA, int sda, float *pB, int sdb, float beta, float *pC, int sdc, float *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	const int bs = 4;

	int i, j, l;

	i = 0;

	for(; i<m-3; i+=4)
		{
		j = 0;
		for(; j<n-3; j+=4)
			{
			kernel_sgemm_nn_4x4_lib4(k, &alpha, &pA[i*sda], &pB[j*bs], sdb, &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd]);
			}
		if(j<n)
			{
			kernel_sgemm_nn_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*bs], sdb, &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
			}
		}
	if(m>i)
		{
		goto left_4;
		}

	// common return if i==m
	return;

	// clean up loops definitions

	left_4:
	j = 0;
	for(; j<n; j+=4)
		{
		kernel_sgemm_nn_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*bs], sdb, &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
		}
	return;

	}



void strmm_nt_ru_lib(int m, int n, float alpha, float *pA, int sda, float *pB, int sdb, float beta, float *pC, int sdc, float *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	const int bs = 4;

	int i, j;

	i = 0;
	for(; i<m-3; i+=4)
		{
		j = 0;
		for(; j<n-3; j+=4)
			{
			kernel_strmm_nt_ru_4x4_lib4(n-j, &alpha, &pA[j*bs+i*sda], &pB[j*bs+j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd]);
			}
		if(j<n) // TODO specialized edge routine
			{
			kernel_strmm_nt_ru_4x4_vs_lib4(n-j, &alpha, &pA[j*bs+i*sda], &pB[j*bs+j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
			}
		}
	if(i<m)
		{
		goto left_4;
		}

	// common return
	return;

	left_4:
	j = 0;
//	for(; j<n-3; j+=4)
	for(; j<n; j+=4)
		{
		kernel_strmm_nt_ru_4x4_vs_lib4(n-j, &alpha, &pA[j*bs+i*sda], &pB[j*bs+j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
		}
//	if(j<n) // TODO specialized edge routine
//		{
//		kernel_strmm_nt_ru_4x4_vs_lib4(n-j, &pA[j*bs+i*sda], &pB[j*bs+j*sdb], alg, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
//		}
	return;

	}



// D <= B * A^{-T} , with A lower triangular with unit diagonal
void strsm_nt_rl_one_lib(int m, int n, float *pA, int sda, float *pB, int sdb, float *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	const int bs = 4;

	int i, j;

	i = 0;

	for(; i<m-3; i+=4)
		{
		j = 0;
		for(; j<n-3; j+=4)
			{
			kernel_strsm_nt_rl_one_4x4_lib4(j, &pD[i*sdd], &pA[j*sda], &pB[j*bs+i*sdb], &pD[j*bs+i*sdd], &pA[j*bs+j*sda]);
			}
		if(j<n)
			{
			kernel_strsm_nt_rl_one_4x4_vs_lib4(j, &pD[i*sdd], &pA[j*sda], &pB[j*bs+i*sdb], &pD[j*bs+i*sdd], &pA[j*bs+j*sda], m-i, n-j);
			}
		}
	if(m>i)
		{
		goto left_4;
		}

	// common return if i==m
	return;

	left_4:
	j = 0;
	for(; j<n; j+=4)
		{
		kernel_strsm_nt_rl_one_4x4_vs_lib4(j, &pD[i*sdd], &pA[j*sda], &pB[j*bs+i*sdb], &pD[j*bs+i*sdd], &pA[j*bs+j*sda], m-i, n-j);
		}
	return;

	}


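/*
 * Sketch of the block recursion implemented by strsm_nt_rl_one_lib above:
 * D solves D * A^T = B with A unit lower triangular, so for each 4-row block
 * of D and each 4-column block j the kernel first subtracts the contribution
 * of the block columns already computed and then solves against the unit
 * lower triangular diagonal block of A,
 *
 *     D[:,j] = ( B[:,j] - sum_{l<j} D[:,l] * A[j,l]^T ) * A[j,j]^{-T}
 *
 * which is why the kernel receives j as the gemm depth, &pD[i*sdd] and
 * &pA[j*sda] as the gemm operands, and &pA[j*bs+j*sda] as the diagonal block.
 */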

// D <= B * A^{-T} , with A upper triangular employing explicit inverse of diagonal
void strsm_nt_ru_inv_lib(int m, int n, float *pA, int sda, float *inv_diag_A, float *pB, int sdb, float *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	const int bs = 4;

	int i, j, idx;

	int rn = n%4;

	float *dummy;

	i = 0;

	for(; i<m-3; i+=4)
		{
		j = 0;
		// clean at the end
		if(rn>0)
			{
			idx = n-rn;
			kernel_strsm_nt_ru_inv_4x4_vs_lib4(0, dummy, dummy, &pB[i*sdb+idx*bs], &pD[i*sdd+idx*bs], &pA[idx*sda+idx*bs], &inv_diag_A[idx], m-i, rn);
			j += rn;
			}
		for(; j<n; j+=4)
			{
			idx = n-j-4;
			kernel_strsm_nt_ru_inv_4x4_lib4(j, &pD[i*sdd+(idx+4)*bs], &pA[idx*sda+(idx+4)*bs], &pB[i*sdb+idx*bs], &pD[i*sdd+idx*bs], &pA[idx*sda+idx*bs], &inv_diag_A[idx]);
			}
		}
	if(m>i)
		{
		goto left_4;
		}

	// common return if i==m
	return;

	left_4:
	j = 0;
	// TODO
	// clean at the end
	if(rn>0)
		{
		idx = n-rn;
		kernel_strsm_nt_ru_inv_4x4_vs_lib4(0, dummy, dummy, &pB[i*sdb+idx*bs], &pD[i*sdd+idx*bs], &pA[idx*sda+idx*bs], &inv_diag_A[idx], m-i, rn);
		j += rn;
		}
	for(; j<n; j+=4)
		{
		idx = n-j-4;
		kernel_strsm_nt_ru_inv_4x4_vs_lib4(j, &pD[i*sdd+(idx+4)*bs], &pA[idx*sda+(idx+4)*bs], &pB[i*sdb+idx*bs], &pD[i*sdd+idx*bs], &pA[idx*sda+idx*bs], &inv_diag_A[idx], m-i, 4);
		}
	return;

	}



// D <= A^{-1} * B , with A lower triangular with unit diagonal
void strsm_nn_ll_one_lib(int m, int n, float *pA, int sda, float *pB, int sdb, float *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	const int bs = 4;

	int i, j;

	i = 0;

	for( ; i<m-3; i+=4)
		{
		j = 0;
		for( ; j<n-3; j+=4)
			{
			kernel_strsm_nn_ll_one_4x4_lib4(i, pA+i*sda, pD+j*bs, sdd, pB+i*sdb+j*bs, pD+i*sdd+j*bs, pA+i*sda+i*bs);
			}
		if(j<n)
			{
			kernel_strsm_nn_ll_one_4x4_vs_lib4(i, pA+i*sda, pD+j*bs, sdd, pB+i*sdb+j*bs, pD+i*sdd+j*bs, pA+i*sda+i*bs, m-i, n-j);
			}
		}
	if(i<m)
		{
		goto left_4;
		}

	// common return
	return;

	left_4:
	j = 0;
	for( ; j<n; j+=4)
		{
		kernel_strsm_nn_ll_one_4x4_vs_lib4(i, pA+i*sda, pD+j*bs, sdd, pB+i*sdb+j*bs, pD+i*sdd+j*bs, pA+i*sda+i*bs, m-i, n-j);
		}
	return;

	}



// D <= A^{-1} * B , with A upper triangular employing explicit inverse of diagonal
void strsm_nn_lu_inv_lib(int m, int n, float *pA, int sda, float *inv_diag_A, float *pB, int sdb, float *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	const int bs = 4;

	int i, j, idx;
	float *dummy;

	i = 0;
	int rm = m%4;
	if(rm>0)
		{
		// TODO code explicitly the final case
		idx = m-rm; // position of the part to do
		j = 0;
		for( ; j<n; j+=4)
			{
			kernel_strsm_nn_lu_inv_4x4_vs_lib4(0, dummy, dummy, 0, pB+idx*sdb+j*bs, pD+idx*sdd+j*bs, pA+idx*sda+idx*bs, inv_diag_A+idx, rm, n-j);
			}
		// TODO
		i += rm;
		}
//	int em = m-rm;
	for( ; i<m; i+=4)
		{
		idx = m-i; // position of already done part
		j = 0;
		for( ; j<n-3; j+=4)
			{
			kernel_strsm_nn_lu_inv_4x4_lib4(i, pA+(idx-4)*sda+idx*bs, pD+idx*sdd+j*bs, sdd, pB+(idx-4)*sdb+j*bs, pD+(idx-4)*sdd+j*bs, pA+(idx-4)*sda+(idx-4)*bs, inv_diag_A+(idx-4));
			}
		if(j<n)
			{
			kernel_strsm_nn_lu_inv_4x4_vs_lib4(i, pA+(idx-4)*sda+idx*bs, pD+idx*sdd+j*bs, sdd, pB+(idx-4)*sdb+j*bs, pD+(idx-4)*sdd+j*bs, pA+(idx-4)*sda+(idx-4)*bs, inv_diag_A+(idx-4), 4, n-j);
			}
		}

	// common return
	return;

	}

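/*
 * Note on strsm_nn_lu_inv_lib above: with A upper triangular and applied from
 * the left, the row blocks of D are computed bottom-up (idx = m-i marks the
 * rows already solved), i.e. classic back substitution,
 *
 *     D[i,:] = inv(A[i,i]) * ( B[i,:] - sum_{l>i} A[i,l] * D[l,:] )
 *
 * with inv_diag_A holding the reciprocals of the diagonal of A, so that the
 * kernels multiply by the inverted diagonal instead of dividing.
 */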


/****************************
* new interface
****************************/

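/*
 * The libstr wrappers below operate on struct s_strmat descriptors. The
 * fields used in this file are: pA, the pointer to the panel-major data;
 * cn, the padded number of columns, used as the panel stride (sda, sdb, ...);
 * and dA together with use_dA, a cached copy of the reciprocals of the
 * diagonal of A that the triangular solves reuse when ai==0 and aj==0.
 */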

#if defined(LA_HIGH_PERFORMANCE)



// sgemm nt
void sgemm_nt_libstr(int m, int n, int k, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, float beta, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
	{

	if(m<=0 | n<=0)
		return;

	const int bs = 4;

	int sda = sA->cn;
	int sdb = sB->cn;
	int sdc = sC->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pC = sC->pA + cj*bs;
	float *pD = sD->pA + dj*bs;

	if(ai==0 & bi==0 & ci==0 & di==0)
		{
		sgemm_nt_lib(m, n, k, alpha, pA, sda, pB, sdb, beta, pC, sdc, pD, sdd);
		return;
		}

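	// non-aligned case: step pA and pB down to the row panels holding rows ai
	// and bi; ci0/di0 are the row indices of C and D shifted by A's
	// within-panel offset (ai%bs), and offsetC/offsetD below are the
	// resulting within-panel offsets passed to the _gen_ kernels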
	pA += ai/bs*bs*sda;
	pB += bi/bs*bs*sdb;
	int ci0 = ci-ai%bs;
	int di0 = di-ai%bs;
	int offsetC;
	int offsetD;
	if(ci0>=0)
		{
		pC += ci0/bs*bs*sdc;
		offsetC = ci0%bs;
		}
	else
		{
		pC += -4*sdc;
		offsetC = bs+ci0;
		}
	if(di0>=0)
		{
		pD += di0/bs*bs*sdd;
		offsetD = di0%bs;
		}
	else
		{
		pD += -4*sdd;
		offsetD = bs+di0;
		}

	int i, j, l;

	int idxB;

	i = 0;
	// clean up at the beginning
	if(ai%bs!=0)
		{
		j = 0;
		idxB = 0;
		// clean up at the beginning
		if(bi%bs!=0)
			{
			kernel_sgemm_nt_4x4_gen_lib4(k, &alpha, &pA[i*sda], &pB[idxB*sdb], &beta, offsetC, &pC[j*bs+i*sdc]-bi%bs*bs, sdc, offsetD, &pD[j*bs+i*sdd]-bi%bs*bs, sdd, ai%bs, m-i, bi%bs, n-j);
			j += bs-bi%bs;
			idxB += 4;
			}
		// main loop
		for(; j<n; j+=4)
			{
			kernel_sgemm_nt_4x4_gen_lib4(k, &alpha, &pA[i*sda], &pB[idxB*sdb], &beta, offsetC, &pC[j*bs+i*sdc], sdc, offsetD, &pD[j*bs+i*sdd], sdd, ai%bs, m-i, 0, n-j);
			idxB += 4;
			}
		m -= bs-ai%bs;
		pA += bs*sda;
		pC += bs*sdc;
		pD += bs*sdd;
		}
	// main loop
	for(; i<m; i+=4)
		{
		j = 0;
		idxB = 0;
		// clean up at the beginning
		if(bi%bs!=0)
			{
			kernel_sgemm_nt_4x4_gen_lib4(k, &alpha, &pA[i*sda], &pB[idxB*sdb], &beta, offsetC, &pC[j*bs+i*sdc]-bi%bs*bs, sdc, offsetD, &pD[j*bs+i*sdd]-bi%bs*bs, sdd, 0, m-i, bi%bs, n-j);
			j += bs-bi%bs;
			idxB += 4;
			}
		// main loop
		for(; j<n; j+=4)
			{
			kernel_sgemm_nt_4x4_gen_lib4(k, &alpha, &pA[i*sda], &pB[idxB*sdb], &beta, offsetC, &pC[j*bs+i*sdc], sdc, offsetD, &pD[j*bs+i*sdd], sdd, 0, m-i, 0, n-j);
			idxB += 4;
			}
		}

	return;

	}

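#if 0
// Illustrative usage sketch (not compiled): D = alpha*A*B^T + beta*C via
// sgemm_nt_libstr on small matrices. The helper names s_allocate_strmat,
// s_cvt_mat2strmat and s_free_strmat are assumed to be the ones provided by
// the BLASFEO s_aux (ext_dep) headers of this version; adjust them if the
// actual API differs.
static void example_sgemm_nt_libstr()
	{
	int m = 8, n = 8, k = 8;
	int ii;
	float A[8*8], B[8*8], C[8*8];
	for(ii=0; ii<m*k; ii++) A[ii] = ii;        // column-major inputs
	for(ii=0; ii<n*k; ii++) B[ii] = ii%5;
	for(ii=0; ii<m*n; ii++) C[ii] = 0.0;
	struct s_strmat sA, sB, sC, sD;
	s_allocate_strmat(m, k, &sA);
	s_allocate_strmat(n, k, &sB);
	s_allocate_strmat(m, n, &sC);
	s_allocate_strmat(m, n, &sD);
	s_cvt_mat2strmat(m, k, A, m, &sA, 0, 0);   // pack column-major -> panel-major
	s_cvt_mat2strmat(n, k, B, n, &sB, 0, 0);
	s_cvt_mat2strmat(m, n, C, m, &sC, 0, 0);
	sgemm_nt_libstr(m, n, k, 1.0, &sA, 0, 0, &sB, 0, 0, 0.0, &sC, 0, 0, &sD, 0, 0);
	s_free_strmat(&sA);
	s_free_strmat(&sB);
	s_free_strmat(&sC);
	s_free_strmat(&sD);
	}
#endif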

// sgemm nn
void sgemm_nn_libstr(int m, int n, int k, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, float beta, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
	{
	if(m<=0 || n<=0)
		return;
	if(ai!=0 | bi!=0 | ci!=0 | di!=0)
		{
		printf("\nsgemm_nn_libstr: feature not implemented yet: ai=%d, bi=%d, ci=%d, di=%d\n", ai, bi, ci, di);
		exit(1);
		}
	const int bs = 4;
	int sda = sA->cn;
	int sdb = sB->cn;
	int sdc = sC->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pC = sC->pA + cj*bs;
	float *pD = sD->pA + dj*bs;
	sgemm_nn_lib(m, n, k, alpha, pA, sda, pB, sdb, beta, pC, sdc, pD, sdd);
	return;
	}



// strsm_nn_llu
void strsm_llnu_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, struct s_strmat *sD, int di, int dj)
	{
	if(ai!=0 | bi!=0 | di!=0 | alpha!=1.0)
		{
		printf("\nstrsm_llnu_libstr: feature not implemented yet: ai=%d, bi=%d, di=%d, alpha=%f\n", ai, bi, di, alpha);
		exit(1);
		}
	const int bs = 4;
	// TODO alpha
	int sda = sA->cn;
	int sdb = sB->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pD = sD->pA + dj*bs;
	strsm_nn_ll_one_lib(m, n, pA, sda, pB, sdb, pD, sdd);
	return;
	}



// strsm_nn_lun
void strsm_lunn_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, struct s_strmat *sD, int di, int dj)
	{
	if(ai!=0 | bi!=0 | di!=0 | alpha!=1.0)
		{
		printf("\nstrsm_lunn_libstr: feature not implemented yet: ai=%d, bi=%d, di=%d, alpha=%f\n", ai, bi, di, alpha);
		exit(1);
		}
	const int bs = 4;
	// TODO alpha
	int sda = sA->cn;
	int sdb = sB->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pD = sD->pA + dj*bs;
	float *dA = sA->dA;
	int ii;
	if(ai==0 & aj==0)
		{
		if(sA->use_dA!=1)
			{
			sdiaex_lib(n, 1.0, ai, pA, sda, dA);
			for(ii=0; ii<n; ii++)
				dA[ii] = 1.0 / dA[ii];
			sA->use_dA = 1;
			}
		}
	else
		{
		sdiaex_lib(n, 1.0, ai, pA, sda, dA);
		for(ii=0; ii<n; ii++)
			dA[ii] = 1.0 / dA[ii];
		sA->use_dA = 0;
		}
	strsm_nn_lu_inv_lib(m, n, pA, sda, dA, pB, sdb, pD, sdd);
	return;
	}



// strsm_right_lower_transposed_notunit
void strsm_rltn_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, struct s_strmat *sD, int di, int dj)
	{

	if(ai!=0 | bi!=0 | di!=0 | alpha!=1.0)
		{
		printf("\nstrsm_rltn_libstr: feature not implemented yet: ai=%d, bi=%d, di=%d, alpha=%f\n", ai, bi, di, alpha);
		exit(1);
		}

	const int bs = 4;

	// TODO alpha

	int sda = sA->cn;
	int sdb = sB->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pD = sD->pA + dj*bs;
	float *dA = sA->dA;

	int i, j;

	if(ai==0 & aj==0)
		{
		if(sA->use_dA!=1)
			{
			sdiaex_lib(n, 1.0, ai, pA, sda, dA);
			for(i=0; i<n; i++)
				dA[i] = 1.0 / dA[i];
			sA->use_dA = 1;
			}
		}
	else
		{
		sdiaex_lib(n, 1.0, ai, pA, sda, dA);
		for(i=0; i<n; i++)
			dA[i] = 1.0 / dA[i];
		sA->use_dA = 0;
		}

	if(m<=0 || n<=0)
		return;

	i = 0;

	for(; i<m-3; i+=4)
		{
		j = 0;
		for(; j<n-3; j+=4)
			{
			kernel_strsm_nt_rl_inv_4x4_lib4(j, &pD[i*sdd], &pA[j*sda], &pB[j*bs+i*sdb], &pD[j*bs+i*sdd], &pA[j*bs+j*sda], &dA[j]);
			}
		if(j<n)
			{
			kernel_strsm_nt_rl_inv_4x4_vs_lib4(j, &pD[i*sdd], &pA[j*sda], &pB[j*bs+i*sdb], &pD[j*bs+i*sdd], &pA[j*bs+j*sda], &dA[j], m-i, n-j);
			}
		}
	if(m>i)
		{
		goto left_4;
		}

	// common return if i==m
	return;

	left_4:
	j = 0;
	for(; j<n; j+=4)
		{
		kernel_strsm_nt_rl_inv_4x4_vs_lib4(j, &pD[i*sdd], &pA[j*sda], &pB[j*bs+i*sdb], &pD[j*bs+i*sdd], &pA[j*bs+j*sda], &dA[j], m-i, n-j);
		}
	return;

	}



// strsm_right_lower_transposed_unit
void strsm_rltu_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, struct s_strmat *sD, int di, int dj)
	{
	if(ai!=0 | bi!=0 | di!=0 | alpha!=1.0)
		{
		printf("\nstrsm_rltu_libstr: feature not implemented yet: ai=%d, bi=%d, di=%d, alpha=%f\n", ai, bi, di, alpha);
		exit(1);
		}
	const int bs = 4;
	// TODO alpha
	int sda = sA->cn;
	int sdb = sB->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pD = sD->pA + dj*bs;
	strsm_nt_rl_one_lib(m, n, pA, sda, pB, sdb, pD, sdd);
	return;
	}



// strsm_right_upper_transposed_notunit
void strsm_rutn_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, struct s_strmat *sD, int di, int dj)
	{
	if(ai!=0 | bi!=0 | di!=0 | alpha!=1.0)
		{
		printf("\nstrsm_rutn_libstr: feature not implemented yet: ai=%d, bi=%d, di=%d, alpha=%f\n", ai, bi, di, alpha);
		exit(1);
		}
	const int bs = 4;
	// TODO alpha
	int sda = sA->cn;
	int sdb = sB->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pD = sD->pA + dj*bs;
	float *dA = sA->dA;
	int ii;
	if(ai==0 & aj==0)
		{
		if(sA->use_dA!=1)
			{
			sdiaex_lib(n, 1.0, ai, pA, sda, dA);
			for(ii=0; ii<n; ii++)
				dA[ii] = 1.0 / dA[ii];
			sA->use_dA = 1;
			}
		}
	else
		{
		sdiaex_lib(n, 1.0, ai, pA, sda, dA);
		for(ii=0; ii<n; ii++)
			dA[ii] = 1.0 / dA[ii];
		sA->use_dA = 0;
		}
	strsm_nt_ru_inv_lib(m, n, pA, sda, dA, pB, sdb, pD, sdd);
	return;
	}



// strmm_right_upper_transposed_notunit (B, i.e. the first matrix, is triangular !!!)
void strmm_rutn_libstr(int m, int n, float alpha, struct s_strmat *sB, int bi, int bj, struct s_strmat *sA, int ai, int aj, struct s_strmat *sD, int di, int dj)
	{
	if(ai!=0 | bi!=0 | di!=0)
		{
		printf("\nstrmm_rutn_libstr: feature not implemented yet: ai=%d, bi=%d, di=%d\n", ai, bi, di);
		exit(1);
		}
	const int bs = 4;
	int sda = sA->cn;
	int sdb = sB->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pD = sD->pA + dj*bs;
	strmm_nt_ru_lib(m, n, alpha, pA, sda, pB, sdb, 0.0, pD, sdd, pD, sdd);
	return;
	}



// strmm_right_lower_nottransposed_notunit (B, i.e. the first matrix, is triangular !!!)
void strmm_rlnn_libstr(int m, int n, float alpha, struct s_strmat *sB, int bi, int bj, struct s_strmat *sA, int ai, int aj, struct s_strmat *sD, int di, int dj)
	{

	const int bs = 4;

	int sda = sA->cn;
	int sdb = sB->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pD = sD->pA + dj*bs;

	pA += ai/bs*bs*sda;
	pB += bi/bs*bs*sdb;
	int offsetB = bi%bs;
	int di0 = di-ai%bs;
	int offsetD;
	if(di0>=0)
		{
		pD += di0/bs*bs*sdd;
		offsetD = di0%bs;
		}
	else
		{
		pD += -4*sdd;
		offsetD = bs+di0;
		}

	int ii, jj;

	ii = 0;
	if(ai%bs!=0)
		{
		jj = 0;
		for(; jj<n; jj+=4)
			{
			kernel_strmm_nn_rl_4x4_gen_lib4(n-jj, &alpha, &pA[ii*sda+jj*bs], offsetB, &pB[jj*sdb+jj*bs], sdb, offsetD, &pD[ii*sdd+jj*bs], sdd, ai%bs, m-ii, 0, n-jj);
			}
		m -= bs-ai%bs;
		pA += bs*sda;
		pD += bs*sdd;
		}
	if(offsetD==0)
		{
		for(; ii<m-3; ii+=4)
			{
			jj = 0;
			for(; jj<n-5; jj+=4)
				{
				kernel_strmm_nn_rl_4x4_lib4(n-jj, &alpha, &pA[ii*sda+jj*bs], offsetB, &pB[jj*sdb+jj*bs], sdb, &pD[ii*sdd+jj*bs]);
				}
			for(; jj<n; jj+=4)
				{
				kernel_strmm_nn_rl_4x4_gen_lib4(n-jj, &alpha, &pA[ii*sda+jj*bs], offsetB, &pB[jj*sdb+jj*bs], sdb, 0, &pD[ii*sdd+jj*bs], sdd, 0, 4, 0, n-jj);
				}
			}
		if(ii<m)
			{
			goto left_4;
			}
		}
	else
		{
		for(; ii<m; ii+=4)
			{
			jj = 0;
			for(; jj<n; jj+=4)
				{
				kernel_strmm_nn_rl_4x4_gen_lib4(n-jj, &alpha, &pA[ii*sda+jj*bs], offsetB, &pB[jj*sdb+jj*bs], sdb, offsetD, &pD[ii*sdd+jj*bs], sdd, 0, m-ii, 0, n-jj);
				}
			}
		}

	// common return if i==m
	return;

	// clean up loops definitions

	left_4:
	jj = 0;
	for(; jj<n; jj+=4)
		{
		kernel_strmm_nn_rl_4x4_gen_lib4(n-jj, &alpha, &pA[ii*sda+jj*bs], offsetB, &pB[jj*sdb+jj*bs], sdb, offsetD, &pD[ii*sdd+jj*bs], sdd, 0, m-ii, 0, n-jj);
		}
	return;

	}



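// ssyrk_lower_nottransposed (only the lower triangular part of the m x m result D = beta*C + alpha*A*B^T is computed)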
void ssyrk_ln_libstr(int m, int k, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, float beta, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
	{

	if(m<=0)
		return;

	if(ai!=0 | bi!=0 | ci!=0 | di!=0)
		{
		printf("\nssyrk_ln_libstr: feature not implemented yet: ai=%d, bi=%d, ci=%d, di=%d\n", ai, bi, ci, di);
		exit(1);
		}

	const int bs = 4;

	int sda = sA->cn;
	int sdb = sB->cn;
	int sdc = sC->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pC = sC->pA + cj*bs;
	float *pD = sD->pA + dj*bs;

//	ssyrk_nt_l_lib(m, n, k, alpha, pA, sda, pB, sdb, beta, pC, sdc, pD, sdd);

	int i, j, l;

	i = 0;

	for(; i<m-3; i+=4)
		{
		j = 0;
		for(; j<i; j+=4)
			{
			kernel_sgemm_nt_4x4_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd]);
			}
		kernel_ssyrk_nt_l_4x4_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd]);
		}
	if(m>i)
		{
		goto left_4;
		}

	// common return if i==m
	return;

	// clean up loops definitions

	left_4:
	j = 0;
	for(; j<i; j+=4)
		{
		kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, m-j);
		}
	kernel_ssyrk_nt_l_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, m-j);
	return;

	}



void ssyrk_ln_mn_libstr(int m, int n, int k, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, float beta, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
	{

	if(m<=0 || n<=0)
		return;

	if(ai!=0 | bi!=0 | ci!=0 | di!=0)
		{
		printf("\nssyrk_ln_mn_libstr: feature not implemented yet: ai=%d, bi=%d, ci=%d, di=%d\n", ai, bi, ci, di);
		exit(1);
		}

	const int bs = 4;

	int sda = sA->cn;
	int sdb = sB->cn;
	int sdc = sC->cn;
	int sdd = sD->cn;
	float *pA = sA->pA + aj*bs;
	float *pB = sB->pA + bj*bs;
	float *pC = sC->pA + cj*bs;
	float *pD = sD->pA + dj*bs;

//	ssyrk_nt_l_lib(m, n, k, alpha, pA, sda, pB, sdb, beta, pC, sdc, pD, sdd);

	int i, j, l;

	i = 0;

	for(; i<m-3; i+=4)
		{
		j = 0;
		for(; j<i && j<n-3; j+=4)
			{
			kernel_sgemm_nt_4x4_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd]);
			}
		if(j<n)
			{
			if(j<i) // sgemm
				{
				kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
				}
			else // ssyrk
				{
				if(j<n-3)
					{
					kernel_ssyrk_nt_l_4x4_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd]);
					}
				else
					{
					kernel_ssyrk_nt_l_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
					}
				}
			}
		}
	if(m>i)
		{
		goto left_4;
		}

	// common return if i==m
	return;

	// clean up loops definitions

	left_4:
	j = 0;
	for(; j<i && j<n; j+=4)
		{
		kernel_sgemm_nt_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
		}
	if(j<n)
		{
		kernel_ssyrk_nt_l_4x4_vs_lib4(k, &alpha, &pA[i*sda], &pB[j*sdb], &beta, &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], m-i, n-j);
		}
	return;

	}



#else

#error : wrong LA choice

#endif



