blob: b7a947d1d49ed5af6dea7a3b490d1d3c3d7f23bc [file] [log] [blame]
Austin Schuh9a24b372018-01-28 16:12:29 -08001/**************************************************************************************************
2* *
3* This file is part of BLASFEO. *
4* *
5* BLASFEO -- BLAS For Embedded Optimization. *
6* Copyright (C) 2016-2017 by Gianluca Frison. *
7* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
8* All rights reserved. *
9* *
10* HPMPC is free software; you can redistribute it and/or *
11* modify it under the terms of the GNU Lesser General Public *
12* License as published by the Free Software Foundation; either *
13* version 2.1 of the License, or (at your option) any later version. *
14* *
15* HPMPC is distributed in the hope that it will be useful, *
16* but WITHOUT ANY WARRANTY; without even the implied warranty of *
17* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
18* See the GNU Lesser General Public License for more details. *
19* *
20* You should have received a copy of the GNU Lesser General Public *
21* License along with HPMPC; if not, write to the Free Software *
22* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA *
23* *
24* Author: Gianluca Frison, giaf (at) dtu.dk *
25* gianluca.frison (at) imtek.uni-freiburg.de *
26* *
27**************************************************************************************************/
28
29#include <stdlib.h>
30#include <stdio.h>
31
32#include "../include/blasfeo_common.h"
33#include "../include/blasfeo_s_kernel.h"
34#include "../include/blasfeo_s_aux.h"
35
36
37
38
39#if defined(LA_HIGH_PERFORMANCE)
40
41
42
43void sgemv_n_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, float beta, struct s_strvec *sy, int yi, struct s_strvec *sz, int zi)
44 {
45
46 if(m<0)
47 return;
48
49 const int bs = 4;
50
51 int i;
52
53 int sda = sA->cn;
54 float *pA = sA->pA + aj*bs + ai/bs*bs*sda;
55 float *x = sx->pa + xi;
56 float *y = sy->pa + yi;
57 float *z = sz->pa + zi;
58
59 i = 0;
60 // clean up at the beginning
61 if(ai%bs!=0)
62 {
63 kernel_sgemv_n_4_gen_lib4(n, &alpha, pA, x, &beta, y-ai%bs, z-ai%bs, ai%bs, m+ai%bs);
64 pA += bs*sda;
65 y += 4 - ai%bs;
66 z += 4 - ai%bs;
67 m -= 4 - ai%bs;
68 }
69 // main loop
70 for( ; i<m-3; i+=4)
71 {
72 kernel_sgemv_n_4_lib4(n, &alpha, &pA[i*sda], x, &beta, &y[i], &z[i]);
73 }
74 if(i<m)
75 {
76 kernel_sgemv_n_4_vs_lib4(n, &alpha, &pA[i*sda], x, &beta, &y[i], &z[i], m-i);
77 }
78
79 return;
80
81 }
82
83
84
85void sgemv_t_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, float beta, struct s_strvec *sy, int yi, struct s_strvec *sz, int zi)
86 {
87
88 if(n<=0)
89 return;
90
91 const int bs = 4;
92
93 int i;
94
95 int sda = sA->cn;
96 float *pA = sA->pA + aj*bs + ai/bs*bs*sda + ai%bs;
97 float *x = sx->pa + xi;
98 float *y = sy->pa + yi;
99 float *z = sz->pa + zi;
100
101 if(ai%bs==0)
102 {
103 i = 0;
104 for( ; i<n-3; i+=4)
105 {
106 kernel_sgemv_t_4_lib4(m, &alpha, &pA[i*bs], sda, x, &beta, &y[i], &z[i]);
107 }
108 if(i<n)
109 {
110 kernel_sgemv_t_4_vs_lib4(m, &alpha, &pA[i*bs], sda, x, &beta, &y[i], &z[i], n-i);
111 }
112 }
113 else // TODO kernel 8
114 {
115 i = 0;
116 for( ; i<n; i+=4)
117 {
118 kernel_sgemv_t_4_gen_lib4(m, &alpha, ai%bs, &pA[i*bs], sda, x, &beta, &y[i], &z[i], n-i);
119 }
120 }
121
122 return;
123
124 }
125
126
127
128void sgemv_nt_libstr(int m, int n, float alpha_n, float alpha_t, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx_n, int xi_n, struct s_strvec *sx_t, int xi_t, float beta_n, float beta_t, struct s_strvec *sy_n, int yi_n, struct s_strvec *sy_t, int yi_t, struct s_strvec *sz_n, int zi_n, struct s_strvec *sz_t, int zi_t)
129 {
130
131 if(ai!=0)
132 {
133 printf("\nsgemv_nt_libstr: feature not implemented yet: ai=%d\n", ai);
134 exit(1);
135 }
136
137 const int bs = 4;
138
139 int sda = sA->cn;
140 float *pA = sA->pA + aj*bs; // TODO ai
141 float *x_n = sx_n->pa + xi_n;
142 float *x_t = sx_t->pa + xi_t;
143 float *y_n = sy_n->pa + yi_n;
144 float *y_t = sy_t->pa + yi_t;
145 float *z_n = sz_n->pa + zi_n;
146 float *z_t = sz_t->pa + zi_t;
147
148// sgemv_nt_lib(m, n, alpha_n, alpha_t, pA, sda, x_n, x_t, beta_n, beta_t, y_n, y_t, z_n, z_t);
149
150// if(m<=0 | n<=0)
151// return;
152
153 int ii;
154
155 // copy and scale y_n int z_n
156 ii = 0;
157 for(; ii<m-3; ii+=4)
158 {
159 z_n[ii+0] = beta_n*y_n[ii+0];
160 z_n[ii+1] = beta_n*y_n[ii+1];
161 z_n[ii+2] = beta_n*y_n[ii+2];
162 z_n[ii+3] = beta_n*y_n[ii+3];
163 }
164 for(; ii<m; ii++)
165 {
166 z_n[ii+0] = beta_n*y_n[ii+0];
167 }
168
169 ii = 0;
170 for(; ii<n-3; ii+=4)
171 {
172 kernel_sgemv_nt_4_lib4(m, &alpha_n, &alpha_t, pA+ii*bs, sda, x_n+ii, x_t, &beta_t, y_t+ii, z_n, z_t+ii);
173 }
174 if(ii<n)
175 {
176 kernel_sgemv_nt_4_vs_lib4(m, &alpha_n, &alpha_t, pA+ii*bs, sda, x_n+ii, x_t, &beta_t, y_t+ii, z_n, z_t+ii, n-ii);
177 }
178
179 return;
180 }
181
182
183
184void ssymv_l_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, float beta, struct s_strvec *sy, int yi, struct s_strvec *sz, int zi)
185 {
186
187 if(m<=0 | n<=0)
188 return;
189
190 const int bs = 4;
191
192 int ii, n1;
193
194 int sda = sA->cn;
195 float *pA = sA->pA + aj*bs + ai/bs*bs*sda + ai%bs;
196 float *x = sx->pa + xi;
197 float *y = sy->pa + yi;
198 float *z = sz->pa + zi;
199
200 // copy and scale y int z
201 ii = 0;
202 for(; ii<m-3; ii+=4)
203 {
204 z[ii+0] = beta*y[ii+0];
205 z[ii+1] = beta*y[ii+1];
206 z[ii+2] = beta*y[ii+2];
207 z[ii+3] = beta*y[ii+3];
208 }
209 for(; ii<m; ii++)
210 {
211 z[ii+0] = beta*y[ii+0];
212 }
213
214 // clean up at the beginning
215 if(ai%bs!=0) // 1, 2, 3
216 {
217 n1 = 4-ai%bs;
218 kernel_ssymv_l_4_gen_lib4(m, &alpha, ai%bs, &pA[0], sda, &x[0], &z[0], n<n1 ? n : n1);
219 pA += n1 + n1*bs + (sda-1)*bs;
220 x += n1;
221 z += n1;
222 m -= n1;
223 n -= n1;
224 }
225 // main loop
226 ii = 0;
227 for(; ii<n-3; ii+=4)
228 {
229 kernel_ssymv_l_4_lib4(m-ii, &alpha, &pA[ii*bs+ii*sda], sda, &x[ii], &z[ii]);
230 }
231 // clean up at the end
232 if(ii<n)
233 {
234 kernel_ssymv_l_4_gen_lib4(m-ii, &alpha, 0, &pA[ii*bs+ii*sda], sda, &x[ii], &z[ii], n-ii);
235 }
236
237 return;
238 }
239
240
241
242// m >= n
243void strmv_lnn_libstr(int m, int n, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
244 {
245
246 if(m<=0)
247 return;
248
249 const int bs = 4;
250
251 int sda = sA->cn;
252 float *pA = sA->pA + aj*bs + ai/bs*bs*sda + ai%bs;
253 float *x = sx->pa + xi;
254 float *z = sz->pa + zi;
255
256 if(m-n>0)
257 sgemv_n_libstr(m-n, n, 1.0, sA, ai+n, aj, sx, xi, 0.0, sz, zi+n, sz, zi+n);
258
259 float *pA2 = pA;
260 float *z2 = z;
261 int m2 = n;
262 int n2 = 0;
263 float *pA3, *x3;
264
265 float alpha = 1.0;
266 float beta = 1.0;
267
268 float zt[4];
269
270 int ii, jj, jj_end;
271
272 ii = 0;
273
274 if(ai%4!=0)
275 {
276 pA2 += sda*bs - ai%bs;
277 z2 += bs-ai%bs;
278 m2 -= bs-ai%bs;
279 n2 += bs-ai%bs;
280 }
281
282 pA2 += m2/bs*bs*sda;
283 z2 += m2/bs*bs;
284 n2 += m2/bs*bs;
285
286 if(m2%bs!=0)
287 {
288 //
289 pA3 = pA2 + bs*n2;
290 x3 = x + n2;
291 zt[3] = pA3[3+bs*0]*x3[0] + pA3[3+bs*1]*x3[1] + pA3[3+bs*2]*x3[2] + pA3[3+bs*3]*x3[3];
292 zt[2] = pA3[2+bs*0]*x3[0] + pA3[2+bs*1]*x3[1] + pA3[2+bs*2]*x3[2];
293 zt[1] = pA3[1+bs*0]*x3[0] + pA3[1+bs*1]*x3[1];
294 zt[0] = pA3[0+bs*0]*x3[0];
295 kernel_sgemv_n_4_lib4(n2, &alpha, pA2, x, &beta, zt, zt);
296 for(jj=0; jj<m2%bs; jj++)
297 z2[jj] = zt[jj];
298 }
299 for(; ii<m2-3; ii+=4)
300 {
301 pA2 -= bs*sda;
302 z2 -= 4;
303 n2 -= 4;
304 pA3 = pA2 + bs*n2;
305 x3 = x + n2;
306 z2[3] = pA3[3+bs*0]*x3[0] + pA3[3+bs*1]*x3[1] + pA3[3+bs*2]*x3[2] + pA3[3+bs*3]*x3[3];
307 z2[2] = pA3[2+bs*0]*x3[0] + pA3[2+bs*1]*x3[1] + pA3[2+bs*2]*x3[2];
308 z2[1] = pA3[1+bs*0]*x3[0] + pA3[1+bs*1]*x3[1];
309 z2[0] = pA3[0+bs*0]*x3[0];
310 kernel_sgemv_n_4_lib4(n2, &alpha, pA2, x, &beta, z2, z2);
311 }
312 if(ai%4!=0)
313 {
314 if(ai%bs==1)
315 {
316 zt[2] = pA[2+bs*0]*x[0] + pA[2+bs*1]*x[1] + pA[2+bs*2]*x[2];
317 zt[1] = pA[1+bs*0]*x[0] + pA[1+bs*1]*x[1];
318 zt[0] = pA[0+bs*0]*x[0];
319 jj_end = 4-ai%bs<n ? 4-ai%bs : n;
320 for(jj=0; jj<jj_end; jj++)
321 z[jj] = zt[jj];
322 }
323 else if(ai%bs==2)
324 {
325 zt[1] = pA[1+bs*0]*x[0] + pA[1+bs*1]*x[1];
326 zt[0] = pA[0+bs*0]*x[0];
327 jj_end = 4-ai%bs<n ? 4-ai%bs : n;
328 for(jj=0; jj<jj_end; jj++)
329 z[jj] = zt[jj];
330 }
331 else // if (ai%bs==3)
332 {
333 z[0] = pA[0+bs*0]*x[0];
334 }
335 }
336
337 return;
338
339 }
340
341
342
343// m >= n
344void strmv_ltn_libstr(int m, int n, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
345 {
346
347 if(m<=0)
348 return;
349
350 const int bs = 4;
351
352 int sda = sA->cn;
353 float *pA = sA->pA + aj*bs + ai/bs*bs*sda + ai%bs;
354 float *x = sx->pa + xi;
355 float *z = sz->pa + zi;
356
357 float xt[4];
358 float zt[4];
359
360 float alpha = 1.0;
361 float beta = 1.0;
362
363 int ii, jj, ll, ll_max;
364
365 jj = 0;
366
367 if(ai%bs!=0)
368 {
369
370 if(ai%bs==1)
371 {
372 ll_max = m-jj<3 ? m-jj : 3;
373 for(ll=0; ll<ll_max; ll++)
374 xt[ll] = x[ll];
375 for(; ll<3; ll++)
376 xt[ll] = 0.0;
377 zt[0] = pA[0+bs*0]*xt[0] + pA[1+bs*0]*xt[1] + pA[2+bs*0]*xt[2];
378 zt[1] = pA[1+bs*1]*xt[1] + pA[2+bs*1]*xt[2];
379 zt[2] = pA[2+bs*2]*xt[2];
380 pA += bs*sda - 1;
381 x += 3;
382 kernel_sgemv_t_4_lib4(m-3-jj, &alpha, pA, sda, x, &beta, zt, zt);
383 ll_max = n-jj<3 ? n-jj : 3;
384 for(ll=0; ll<ll_max; ll++)
385 z[ll] = zt[ll];
386 pA += bs*3;
387 z += 3;
388 jj += 3;
389 }
390 else if(ai%bs==2)
391 {
392 ll_max = m-jj<2 ? m-jj : 2;
393 for(ll=0; ll<ll_max; ll++)
394 xt[ll] = x[ll];
395 for(; ll<2; ll++)
396 xt[ll] = 0.0;
397 zt[0] = pA[0+bs*0]*xt[0] + pA[1+bs*0]*xt[1];
398 zt[1] = pA[1+bs*1]*xt[1];
399 pA += bs*sda - 2;
400 x += 2;
401 kernel_sgemv_t_4_lib4(m-2-jj, &alpha, pA, sda, x, &beta, zt, zt);
402 ll_max = n-jj<2 ? n-jj : 2;
403 for(ll=0; ll<ll_max; ll++)
404 z[ll] = zt[ll];
405 pA += bs*2;
406 z += 2;
407 jj += 2;
408 }
409 else // if(ai%bs==3)
410 {
411 ll_max = m-jj<1 ? m-jj : 1;
412 for(ll=0; ll<ll_max; ll++)
413 xt[ll] = x[ll];
414 for(; ll<1; ll++)
415 xt[ll] = 0.0;
416 zt[0] = pA[0+bs*0]*xt[0];
417 pA += bs*sda - 3;
418 x += 1;
419 kernel_sgemv_t_4_lib4(m-1-jj, &alpha, pA, sda, x, &beta, zt, zt);
420 ll_max = n-jj<1 ? n-jj : 1;
421 for(ll=0; ll<ll_max; ll++)
422 z[ll] = zt[ll];
423 pA += bs*1;
424 z += 1;
425 jj += 1;
426 }
427
428 }
429
430 for(; jj<n-3; jj+=4)
431 {
432 zt[0] = pA[0+bs*0]*x[0] + pA[1+bs*0]*x[1] + pA[2+bs*0]*x[2] + pA[3+bs*0]*x[3];
433 zt[1] = pA[1+bs*1]*x[1] + pA[2+bs*1]*x[2] + pA[3+bs*1]*x[3];
434 zt[2] = pA[2+bs*2]*x[2] + pA[3+bs*2]*x[3];
435 zt[3] = pA[3+bs*3]*x[3];
436 pA += bs*sda;
437 x += 4;
438 kernel_sgemv_t_4_lib4(m-4-jj, &alpha, pA, sda, x, &beta, zt, z);
439 pA += bs*4;
440 z += 4;
441 }
442 if(jj<n)
443 {
444 ll_max = m-jj<4 ? m-jj : 4;
445 for(ll=0; ll<ll_max; ll++)
446 xt[ll] = x[ll];
447 for(; ll<4; ll++)
448 xt[ll] = 0.0;
449 zt[0] = pA[0+bs*0]*xt[0] + pA[1+bs*0]*xt[1] + pA[2+bs*0]*xt[2] + pA[3+bs*0]*xt[3];
450 zt[1] = pA[1+bs*1]*xt[1] + pA[2+bs*1]*xt[2] + pA[3+bs*1]*xt[3];
451 zt[2] = pA[2+bs*2]*xt[2] + pA[3+bs*2]*xt[3];
452 zt[3] = pA[3+bs*3]*xt[3];
453 pA += bs*sda;
454 x += 4;
455 kernel_sgemv_t_4_lib4(m-4-jj, &alpha, pA, sda, x, &beta, zt, zt);
456 for(ll=0; ll<n-jj; ll++)
457 z[ll] = zt[ll];
458// pA += bs*4;
459// z += 4;
460 }
461
462 return;
463
464 }
465
466
467
468void strmv_unn_libstr(int m, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
469 {
470
471 if(m<=0)
472 return;
473
474 if(ai!=0)
475 {
476 printf("\ndtrmv_unn_libstr: feature not implemented yet: ai=%d\n", ai);
477 exit(1);
478 }
479
480 const int bs = 4;
481
482 int sda = sA->cn;
483 float *pA = sA->pA + aj*bs; // TODO ai
484 float *x = sx->pa + xi;
485 float *z = sz->pa + zi;
486
487 int i;
488
489 i=0;
490 for(; i<m-3; i+=4)
491 {
492 kernel_strmv_un_4_lib4(m-i, pA, x, z);
493 pA += 4*sda+4*bs;
494 x += 4;
495 z += 4;
496 }
497 if(m>i)
498 {
499 if(m-i==1)
500 {
501 z[0] = pA[0+bs*0]*x[0];
502 }
503 else if(m-i==2)
504 {
505 z[0] = pA[0+bs*0]*x[0] + pA[0+bs*1]*x[1];
506 z[1] = pA[1+bs*1]*x[1];
507 }
508 else // if(m-i==3)
509 {
510 z[0] = pA[0+bs*0]*x[0] + pA[0+bs*1]*x[1] + pA[0+bs*2]*x[2];
511 z[1] = pA[1+bs*1]*x[1] + pA[1+bs*2]*x[2];
512 z[2] = pA[2+bs*2]*x[2];
513 }
514 }
515
516 return;
517
518 }
519
520
521
522void strmv_utn_libstr(int m, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
523 {
524
525 if(m<=0)
526 return;
527
528 if(ai!=0)
529 {
530 printf("\nstrmv_utn_libstr: feature not implemented yet: ai=%d\n", ai);
531 exit(1);
532 }
533
534 const int bs = 4;
535
536 int sda = sA->cn;
537 float *pA = sA->pA + aj*bs; // TODO ai
538 float *x = sx->pa + xi;
539 float *z = sz->pa + zi;
540
541 int ii, idx;
542
543 float *ptrA;
544
545 ii=0;
546 idx = m/bs*bs;
547 if(m%bs!=0)
548 {
549 kernel_strmv_ut_4_vs_lib4(m, pA+idx*bs, sda, x, z+idx, m%bs);
550 ii += m%bs;
551 }
552 idx -= 4;
553 for(; ii<m; ii+=4)
554 {
555 kernel_strmv_ut_4_lib4(idx+4, pA+idx*bs, sda, x, z+idx);
556 idx -= 4;
557 }
558
559 return;
560
561 }
562
563
564
565void strsv_lnn_libstr(int m, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
566 {
567
568 if(m==0)
569 return;
570
571#if defined(DIM_CHECK)
572 // non-negative size
573 if(m<0) printf("\n****** strsv_lnn_libstr : m<0 : %d<0 *****\n", m);
574 // non-negative offset
575 if(ai<0) printf("\n****** strsv_lnn_libstr : ai<0 : %d<0 *****\n", ai);
576 if(aj<0) printf("\n****** strsv_lnn_libstr : aj<0 : %d<0 *****\n", aj);
577 if(xi<0) printf("\n****** strsv_lnn_libstr : xi<0 : %d<0 *****\n", xi);
578 if(zi<0) printf("\n****** strsv_lnn_libstr : zi<0 : %d<0 *****\n", zi);
579 // inside matrix
580 // A: m x k
581 if(ai+m > sA->m) printf("\n***** strsv_lnn_libstr : ai+m > row(A) : %d+%d > %d *****\n", ai, m, sA->m);
582 if(aj+m > sA->n) printf("\n***** strsv_lnn_libstr : aj+m > col(A) : %d+%d > %d *****\n", aj, m, sA->n);
583 // x: m
584 if(xi+m > sx->m) printf("\n***** strsv_lnn_libstr : xi+m > size(x) : %d+%d > %d *****\n", xi, m, sx->m);
585 // z: m
586 if(zi+m > sz->m) printf("\n***** strsv_lnn_libstr : zi+m > size(z) : %d+%d > %d *****\n", zi, m, sz->m);
587#endif
588
589 if(ai!=0)
590 {
591 printf("\nstrsv_lnn_libstr: feature not implemented yet: ai=%d\n", ai);
592 exit(1);
593 }
594
595 const int bs = 4;
596
597 int sda = sA->cn;
598 float *pA = sA->pA + aj*bs; // TODO ai
599 float *dA = sA->dA;
600 float *x = sx->pa + xi;
601 float *z = sz->pa + zi;
602
603 int ii;
604
605 if(ai==0 & aj==0)
606 {
607 if(sA->use_dA!=1)
608 {
609 sdiaex_lib(m, 1.0, ai, pA, sda, dA);
610 for(ii=0; ii<m; ii++)
611 dA[ii] = 1.0 / dA[ii];
612 sA->use_dA = 1;
613 }
614 }
615 else
616 {
617 sdiaex_lib(m, 1.0, ai, pA, sda, dA);
618 for(ii=0; ii<m; ii++)
619 dA[ii] = 1.0 / dA[ii];
620 sA->use_dA = 0;
621 }
622
623 int i;
624
625 if(x!=z)
626 {
627 for(i=0; i<m; i++)
628 z[i] = x[i];
629 }
630
631 i = 0;
632 for( ; i<m-3; i+=4)
633 {
634 kernel_strsv_ln_inv_4_lib4(i, &pA[i*sda], &dA[i], z, &z[i], &z[i]);
635 }
636 if(i<m)
637 {
638 kernel_strsv_ln_inv_4_vs_lib4(i, &pA[i*sda], &dA[i], z, &z[i], &z[i], m-i, m-i);
639 i+=4;
640 }
641
642 return;
643
644 }
645
646
647
648void strsv_lnn_mn_libstr(int m, int n, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
649 {
650
651 if(m==0 | n==0)
652 return;
653
654#if defined(DIM_CHECK)
655 // non-negative size
656 if(m<0) printf("\n****** strsv_lnn_mn_libstr : m<0 : %d<0 *****\n", m);
657 if(n<0) printf("\n****** strsv_lnn_mn_libstr : n<0 : %d<0 *****\n", n);
658 // non-negative offset
659 if(ai<0) printf("\n****** strsv_lnn_mn_libstr : ai<0 : %d<0 *****\n", ai);
660 if(aj<0) printf("\n****** strsv_lnn_mn_libstr : aj<0 : %d<0 *****\n", aj);
661 if(xi<0) printf("\n****** strsv_lnn_mn_libstr : xi<0 : %d<0 *****\n", xi);
662 if(zi<0) printf("\n****** strsv_lnn_mn_libstr : zi<0 : %d<0 *****\n", zi);
663 // inside matrix
664 // A: m x k
665 if(ai+m > sA->m) printf("\n***** strsv_lnn_mn_libstr : ai+m > row(A) : %d+%d > %d *****\n", ai, m, sA->m);
666 if(aj+n > sA->n) printf("\n***** strsv_lnn_mn_libstr : aj+n > col(A) : %d+%d > %d *****\n", aj, n, sA->n);
667 // x: m
668 if(xi+m > sx->m) printf("\n***** strsv_lnn_mn_libstr : xi+m > size(x) : %d+%d > %d *****\n", xi, m, sx->m);
669 // z: m
670 if(zi+m > sz->m) printf("\n***** strsv_lnn_mn_libstr : zi+m > size(z) : %d+%d > %d *****\n", zi, m, sz->m);
671#endif
672
673 if(ai!=0)
674 {
675 printf("\nstrsv_lnn_mn_libstr: feature not implemented yet: ai=%d\n", ai);
676 exit(1);
677 }
678
679 const int bs = 4;
680
681 int sda = sA->cn;
682 float *pA = sA->pA + aj*bs; // TODO ai
683 float *dA = sA->dA;
684 float *x = sx->pa + xi;
685 float *z = sz->pa + zi;
686
687 int ii;
688
689 if(ai==0 & aj==0)
690 {
691 if(sA->use_dA!=1)
692 {
693 sdiaex_lib(n, 1.0, ai, pA, sda, dA);
694 for(ii=0; ii<n; ii++)
695 dA[ii] = 1.0 / dA[ii];
696 sA->use_dA = 1;
697 }
698 }
699 else
700 {
701 sdiaex_lib(n, 1.0, ai, pA, sda, dA);
702 for(ii=0; ii<n; ii++)
703 dA[ii] = 1.0 / dA[ii];
704 sA->use_dA = 0;
705 }
706
707 if(m<n)
708 m = n;
709
710 float alpha = -1.0;
711 float beta = 1.0;
712
713 int i;
714
715 if(x!=z)
716 {
717 for(i=0; i<m; i++)
718 z[i] = x[i];
719 }
720
721 i = 0;
722 for( ; i<n-3; i+=4)
723 {
724 kernel_strsv_ln_inv_4_lib4(i, &pA[i*sda], &dA[i], z, &z[i], &z[i]);
725 }
726 if(i<n)
727 {
728 kernel_strsv_ln_inv_4_vs_lib4(i, &pA[i*sda], &dA[i], z, &z[i], &z[i], m-i, n-i);
729 i+=4;
730 }
731 for( ; i<m-3; i+=4)
732 {
733 kernel_sgemv_n_4_lib4(n, &alpha, &pA[i*sda], z, &beta, &z[i], &z[i]);
734 }
735 if(i<m)
736 {
737 kernel_sgemv_n_4_vs_lib4(n, &alpha, &pA[i*sda], z, &beta, &z[i], &z[i], m-i);
738 i+=4;
739 }
740
741 return;
742
743 }
744
745
746
747void strsv_ltn_libstr(int m, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
748 {
749
750 if(m==0)
751 return;
752
753#if defined(DIM_CHECK)
754 // non-negative size
755 if(m<0) printf("\n****** strsv_ltn_libstr : m<0 : %d<0 *****\n", m);
756 // non-negative offset
757 if(ai<0) printf("\n****** strsv_ltn_libstr : ai<0 : %d<0 *****\n", ai);
758 if(aj<0) printf("\n****** strsv_ltn_libstr : aj<0 : %d<0 *****\n", aj);
759 if(xi<0) printf("\n****** strsv_ltn_libstr : xi<0 : %d<0 *****\n", xi);
760 if(zi<0) printf("\n****** strsv_ltn_libstr : zi<0 : %d<0 *****\n", zi);
761 // inside matrix
762 // A: m x k
763 if(ai+m > sA->m) printf("\n***** strsv_ltn_libstr : ai+m > row(A) : %d+%d > %d *****\n", ai, m, sA->m);
764 if(aj+m > sA->n) printf("\n***** strsv_ltn_libstr : aj+m > col(A) : %d+%d > %d *****\n", aj, m, sA->n);
765 // x: m
766 if(xi+m > sx->m) printf("\n***** strsv_ltn_libstr : xi+m > size(x) : %d+%d > %d *****\n", xi, m, sx->m);
767 // z: m
768 if(zi+m > sz->m) printf("\n***** strsv_ltn_libstr : zi+m > size(z) : %d+%d > %d *****\n", zi, m, sz->m);
769#endif
770
771 if(ai!=0)
772 {
773 printf("\nstrsv_ltn_libstr: feature not implemented yet: ai=%d\n", ai);
774 exit(1);
775 }
776
777 const int bs = 4;
778
779 int sda = sA->cn;
780 float *pA = sA->pA + aj*bs; // TODO ai
781 float *dA = sA->dA;
782 float *x = sx->pa + xi;
783 float *z = sz->pa + zi;
784
785 int ii;
786
787 if(ai==0 & aj==0)
788 {
789 if(sA->use_dA!=1)
790 {
791 sdiaex_lib(m, 1.0, ai, pA, sda, dA);
792 for(ii=0; ii<m; ii++)
793 dA[ii] = 1.0 / dA[ii];
794 sA->use_dA = 1;
795 }
796 }
797 else
798 {
799 sdiaex_lib(m, 1.0, ai, pA, sda, dA);
800 for(ii=0; ii<m; ii++)
801 dA[ii] = 1.0 / dA[ii];
802 sA->use_dA = 0;
803 }
804
805 int i;
806
807 if(x!=z)
808 for(i=0; i<m; i++)
809 z[i] = x[i];
810
811 i=0;
812 if(m%4==1)
813 {
814 kernel_strsv_lt_inv_1_lib4(i+1, &pA[m/bs*bs*sda+(m-i-1)*bs], sda, &dA[m-i-1], &z[m-i-1], &z[m-i-1], &z[m-i-1]);
815 i++;
816 }
817 else if(m%4==2)
818 {
819 kernel_strsv_lt_inv_2_lib4(i+2, &pA[m/bs*bs*sda+(m-i-2)*bs], sda, &dA[m-i-2], &z[m-i-2], &z[m-i-2], &z[m-i-2]);
820 i+=2;
821 }
822 else if(m%4==3)
823 {
824 kernel_strsv_lt_inv_3_lib4(i+3, &pA[m/bs*bs*sda+(m-i-3)*bs], sda, &dA[m-i-3], &z[m-i-3], &z[m-i-3], &z[m-i-3]);
825 i+=3;
826 }
827 for(; i<m-3; i+=4)
828 {
829 kernel_strsv_lt_inv_4_lib4(i+4, &pA[(m-i-4)/bs*bs*sda+(m-i-4)*bs], sda, &dA[m-i-4], &z[m-i-4], &z[m-i-4], &z[m-i-4]);
830 }
831
832 return;
833
834 }
835
836
837
838void strsv_ltn_mn_libstr(int m, int n, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
839 {
840
841 if(m==0)
842 return;
843
844#if defined(DIM_CHECK)
845 // non-negative size
846 if(m<0) printf("\n****** strsv_ltn_mn_libstr : m<0 : %d<0 *****\n", m);
847 if(n<0) printf("\n****** strsv_ltn_mn_libstr : n<0 : %d<0 *****\n", n);
848 // non-negative offset
849 if(ai<0) printf("\n****** strsv_ltn_mn_libstr : ai<0 : %d<0 *****\n", ai);
850 if(aj<0) printf("\n****** strsv_ltn_mn_libstr : aj<0 : %d<0 *****\n", aj);
851 if(xi<0) printf("\n****** strsv_ltn_mn_libstr : xi<0 : %d<0 *****\n", xi);
852 if(zi<0) printf("\n****** strsv_ltn_mn_libstr : zi<0 : %d<0 *****\n", zi);
853 // inside matrix
854 // A: m x k
855 if(ai+m > sA->m) printf("\n***** strsv_ltn_mn_libstr : ai+m > row(A) : %d+%d > %d *****\n", ai, m, sA->m);
856 if(aj+n > sA->n) printf("\n***** strsv_ltn_mn_libstr : aj+n > col(A) : %d+%d > %d *****\n", aj, n, sA->n);
857 // x: m
858 if(xi+m > sx->m) printf("\n***** strsv_ltn_mn_libstr : xi+m > size(x) : %d+%d > %d *****\n", xi, m, sx->m);
859 // z: m
860 if(zi+m > sz->m) printf("\n***** strsv_ltn_mn_libstr : zi+m > size(z) : %d+%d > %d *****\n", zi, m, sz->m);
861#endif
862
863 if(ai!=0)
864 {
865 printf("\nstrsv_ltn_mn_libstr: feature not implemented yet: ai=%d\n", ai);
866 exit(1);
867 }
868
869 const int bs = 4;
870
871 int sda = sA->cn;
872 float *pA = sA->pA + aj*bs; // TODO ai
873 float *dA = sA->dA;
874 float *x = sx->pa + xi;
875 float *z = sz->pa + zi;
876
877 int ii;
878
879 if(ai==0 & aj==0)
880 {
881 if(sA->use_dA!=1)
882 {
883 sdiaex_lib(n, 1.0, ai, pA, sda, dA);
884 for(ii=0; ii<n; ii++)
885 dA[ii] = 1.0 / dA[ii];
886 sA->use_dA = 1;
887 }
888 }
889 else
890 {
891 sdiaex_lib(n, 1.0, ai, pA, sda, dA);
892 for(ii=0; ii<n; ii++)
893 dA[ii] = 1.0 / dA[ii];
894 sA->use_dA = 0;
895 }
896
897 if(n>m)
898 n = m;
899
900 int i;
901
902 if(x!=z)
903 for(i=0; i<m; i++)
904 z[i] = x[i];
905
906 i=0;
907 if(n%4==1)
908 {
909 kernel_strsv_lt_inv_1_lib4(m-n+i+1, &pA[n/bs*bs*sda+(n-i-1)*bs], sda, &dA[n-i-1], &z[n-i-1], &z[n-i-1], &z[n-i-1]);
910 i++;
911 }
912 else if(n%4==2)
913 {
914 kernel_strsv_lt_inv_2_lib4(m-n+i+2, &pA[n/bs*bs*sda+(n-i-2)*bs], sda, &dA[n-i-2], &z[n-i-2], &z[n-i-2], &z[n-i-2]);
915 i+=2;
916 }
917 else if(n%4==3)
918 {
919 kernel_strsv_lt_inv_3_lib4(m-n+i+3, &pA[n/bs*bs*sda+(n-i-3)*bs], sda, &dA[n-i-3], &z[n-i-3], &z[n-i-3], &z[n-i-3]);
920 i+=3;
921 }
922 for(; i<n-3; i+=4)
923 {
924 kernel_strsv_lt_inv_4_lib4(m-n+i+4, &pA[(n-i-4)/bs*bs*sda+(n-i-4)*bs], sda, &dA[n-i-4], &z[n-i-4], &z[n-i-4], &z[n-i-4]);
925 }
926
927 return;
928
929 }
930
931
932
/*
 * Placeholder for the lower-triangular, non-transposed solve variant named
 * "lnu" (presumably unit diagonal -- the name is the only hint).
 *
 * Nothing is computed: for m==0 the function returns, for any other m it
 * prints a "not implemented" message on stdout and terminates the process
 * with exit(1).  With DIM_CHECK defined the usual size/offset diagnostics
 * are printed first.
 */
void strsv_lnu_libstr(int m, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
	{
	if(m!=0)
		{
#if defined(DIM_CHECK)
		// sizes must be non-negative
		if(m<0) printf("\n****** strsv_lnu_libstr : m<0 : %d<0 *****\n", m);
		// offsets must be non-negative
		if(ai<0) printf("\n****** strsv_lnu_libstr : ai<0 : %d<0 *****\n", ai);
		if(aj<0) printf("\n****** strsv_lnu_libstr : aj<0 : %d<0 *****\n", aj);
		if(xi<0) printf("\n****** strsv_lnu_libstr : xi<0 : %d<0 *****\n", xi);
		if(zi<0) printf("\n****** strsv_lnu_libstr : zi<0 : %d<0 *****\n", zi);
		// the m x m block must fit inside A
		if(ai+m > sA->m) printf("\n***** strsv_lnu_libstr : ai+m > row(A) : %d+%d > %d *****\n", ai, m, sA->m);
		if(aj+m > sA->n) printf("\n***** strsv_lnu_libstr : aj+m > col(A) : %d+%d > %d *****\n", aj, m, sA->n);
		// x must hold m entries
		if(xi+m > sx->m) printf("\n***** strsv_lnu_libstr : xi+m > size(x) : %d+%d > %d *****\n", xi, m, sx->m);
		// z must hold m entries
		if(zi+m > sz->m) printf("\n***** strsv_lnu_libstr : zi+m > size(z) : %d+%d > %d *****\n", zi, m, sz->m);
#endif
		printf("\n***** strsv_lnu_libstr : feature not implemented yet *****\n");
		exit(1);
		}
	}
957
958
959
/*
 * Placeholder for the lower-triangular, transposed solve variant named
 * "ltu" (presumably unit diagonal -- the name is the only hint).
 *
 * Nothing is computed: for m==0 the function returns, for any other m it
 * prints a "not implemented" message on stdout and terminates the process
 * with exit(1).  With DIM_CHECK defined the usual size/offset diagnostics
 * are printed first.
 */
void strsv_ltu_libstr(int m, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
	{
	if(m!=0)
		{
#if defined(DIM_CHECK)
		// sizes must be non-negative
		if(m<0) printf("\n****** strsv_ltu_libstr : m<0 : %d<0 *****\n", m);
		// offsets must be non-negative
		if(ai<0) printf("\n****** strsv_ltu_libstr : ai<0 : %d<0 *****\n", ai);
		if(aj<0) printf("\n****** strsv_ltu_libstr : aj<0 : %d<0 *****\n", aj);
		if(xi<0) printf("\n****** strsv_ltu_libstr : xi<0 : %d<0 *****\n", xi);
		if(zi<0) printf("\n****** strsv_ltu_libstr : zi<0 : %d<0 *****\n", zi);
		// the m x m block must fit inside A
		if(ai+m > sA->m) printf("\n***** strsv_ltu_libstr : ai+m > row(A) : %d+%d > %d *****\n", ai, m, sA->m);
		if(aj+m > sA->n) printf("\n***** strsv_ltu_libstr : aj+m > col(A) : %d+%d > %d *****\n", aj, m, sA->n);
		// x must hold m entries
		if(xi+m > sx->m) printf("\n***** strsv_ltu_libstr : xi+m > size(x) : %d+%d > %d *****\n", xi, m, sx->m);
		// z must hold m entries
		if(zi+m > sz->m) printf("\n***** strsv_ltu_libstr : zi+m > size(z) : %d+%d > %d *****\n", zi, m, sz->m);
#endif
		printf("\n***** strsv_ltu_libstr : feature not implemented yet *****\n");
		exit(1);
		}
	}
984
985
986
/*
 * Placeholder for the upper-triangular, non-transposed solve variant named
 * "unn".
 *
 * Nothing is computed: for m==0 the function returns, for any other m it
 * prints a "not implemented" message on stdout and terminates the process
 * with exit(1).  With DIM_CHECK defined the usual size/offset diagnostics
 * are printed first.
 */
void strsv_unn_libstr(int m, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
	{
	if(m!=0)
		{
#if defined(DIM_CHECK)
		// sizes must be non-negative
		if(m<0) printf("\n****** strsv_unn_libstr : m<0 : %d<0 *****\n", m);
		// offsets must be non-negative
		if(ai<0) printf("\n****** strsv_unn_libstr : ai<0 : %d<0 *****\n", ai);
		if(aj<0) printf("\n****** strsv_unn_libstr : aj<0 : %d<0 *****\n", aj);
		if(xi<0) printf("\n****** strsv_unn_libstr : xi<0 : %d<0 *****\n", xi);
		if(zi<0) printf("\n****** strsv_unn_libstr : zi<0 : %d<0 *****\n", zi);
		// the m x m block must fit inside A
		if(ai+m > sA->m) printf("\n***** strsv_unn_libstr : ai+m > row(A) : %d+%d > %d *****\n", ai, m, sA->m);
		if(aj+m > sA->n) printf("\n***** strsv_unn_libstr : aj+m > col(A) : %d+%d > %d *****\n", aj, m, sA->n);
		// x must hold m entries
		if(xi+m > sx->m) printf("\n***** strsv_unn_libstr : xi+m > size(x) : %d+%d > %d *****\n", xi, m, sx->m);
		// z must hold m entries
		if(zi+m > sz->m) printf("\n***** strsv_unn_libstr : zi+m > size(z) : %d+%d > %d *****\n", zi, m, sz->m);
#endif
		printf("\n***** strsv_unn_libstr : feature not implemented yet *****\n");
		exit(1);
		}
	}
1011
1012
1013
/*
 * Placeholder for the upper-triangular, transposed solve variant named
 * "utn".
 *
 * Nothing is computed: for m==0 the function returns, for any other m it
 * prints a "not implemented" message on stdout and terminates the process
 * with exit(1).  With DIM_CHECK defined the usual size/offset diagnostics
 * are printed first.
 */
void strsv_utn_libstr(int m, struct s_strmat *sA, int ai, int aj, struct s_strvec *sx, int xi, struct s_strvec *sz, int zi)
	{
	if(m!=0)
		{
#if defined(DIM_CHECK)
		// sizes must be non-negative
		if(m<0) printf("\n****** strsv_utn_libstr : m<0 : %d<0 *****\n", m);
		// offsets must be non-negative
		if(ai<0) printf("\n****** strsv_utn_libstr : ai<0 : %d<0 *****\n", ai);
		if(aj<0) printf("\n****** strsv_utn_libstr : aj<0 : %d<0 *****\n", aj);
		if(xi<0) printf("\n****** strsv_utn_libstr : xi<0 : %d<0 *****\n", xi);
		if(zi<0) printf("\n****** strsv_utn_libstr : zi<0 : %d<0 *****\n", zi);
		// the m x m block must fit inside A
		if(ai+m > sA->m) printf("\n***** strsv_utn_libstr : ai+m > row(A) : %d+%d > %d *****\n", ai, m, sA->m);
		if(aj+m > sA->n) printf("\n***** strsv_utn_libstr : aj+m > col(A) : %d+%d > %d *****\n", aj, m, sA->n);
		// x must hold m entries
		if(xi+m > sx->m) printf("\n***** strsv_utn_libstr : xi+m > size(x) : %d+%d > %d *****\n", xi, m, sx->m);
		// z must hold m entries
		if(zi+m > sz->m) printf("\n***** strsv_utn_libstr : zi+m > size(z) : %d+%d > %d *****\n", zi, m, sz->m);
#endif
		printf("\n***** strsv_utn_libstr : feature not implemented yet *****\n");
		exit(1);
		}
	}
1038
1039
1040
1041#else
1042
1043#error : wrong LA choice
1044
1045#endif