blob: 7d02d3627037f7bd6d733a167e32f5e574df2f6b [file] [log] [blame]
Austin Schuh9a24b372018-01-28 16:12:29 -08001/**************************************************************************************************
2* *
3* This file is part of BLASFEO. *
4* *
5* BLASFEO -- BLAS For Embedded Optimization. *
6* Copyright (C) 2016-2017 by Gianluca Frison. *
7* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
8* All rights reserved. *
9* *
10* HPMPC is free software; you can redistribute it and/or *
11* modify it under the terms of the GNU Lesser General Public *
12* License as published by the Free Software Foundation; either *
13* version 2.1 of the License, or (at your option) any later version. *
14* *
15* HPMPC is distributed in the hope that it will be useful, *
16* but WITHOUT ANY WARRANTY; without even the implied warranty of *
17* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
18* See the GNU Lesser General Public License for more details. *
19* *
20* You should have received a copy of the GNU Lesser General Public *
21* License along with HPMPC; if not, write to the Free Software *
22* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA *
23* *
24* Author: Gianluca Frison, giaf (at) dtu.dk *
25* gianluca.frison (at) imtek.uni-freiburg.de *
26* *
27**************************************************************************************************/
28
29#include <stdlib.h>
30#include <stdio.h>
31#include <math.h>
32
33#include "../include/blasfeo_common.h"
34#include "../include/blasfeo_s_aux.h"
35#include "../include/blasfeo_s_kernel.h"
36
37
38
39/****************************
40* old interface
41****************************/
42
/*
 * Fused symmetric rank-k update + lower Cholesky factorization
 * (old interface, panel-major storage with panel height bs = 4).
 *
 * The lower-left m x n part of D is produced one 4x4 tile at a time:
 *  - tiles strictly left of the diagonal (column block j < row block i) go
 *    through the fused gemm + trsm kernel;
 *  - diagonal tiles (j == i) go through the fused syrk + potrf kernel.
 * Full-size kernels are used where a whole 4x4 tile fits; the "_vs"
 * (variable size) kernels handle the partial last row and/or column block.
 * The arithmetic convention (presumably D = chol(C + A*B^T)) is defined by
 * the kernel_* routines -- confirm against their documentation.
 *
 * m, n        rows / columns to compute; nothing is done if either is <= 0
 * k           inner dimension of the A*B^T product
 * pA, sda     left factor and its panel stride
 * pB, sdb     right factor and its panel stride
 * pC, sdc     input matrix and its panel stride
 * pD, sdd     output factor and its panel stride
 * inv_diag_D  handed to the kernels at offset j of every column block;
 *             presumably receives the reciprocals of diag(D)
 */
void ssyrk_spotrf_nt_l_lib(int m, int n, int k, float *pA, int sda, float *pB, int sdb, float *pC, int sdc, float *pD, int sdd, float *inv_diag_D)
{
	const int bs = 4;
	int i, j;

	if (m <= 0 || n <= 0)
		return;

	/* full 4-row panels */
	for (i = 0; i < m-3; i += 4) {
		/* full 4x4 tiles left of the diagonal */
		for (j = 0; j < i && j < n-3; j += 4) {
			kernel_sgemm_strsm_nt_rl_inv_4x4_lib4(k, &pA[i*sda], &pB[j*sdb], j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &inv_diag_D[j]);
		}
		if (j < n) {
			/*
			 * The loop above always leaves j <= i. If j < i it stopped on
			 * the partial last column block (n-j < 4), which is still
			 * below the diagonal: gemm + trsm. Otherwise j == i and this
			 * is the diagonal tile: syrk + potrf. (Testing `i < j` here,
			 * as before, is never true and sent below-diagonal tiles to
			 * the potrf kernel, overwriting the diagonal tile j.)
			 */
			if (j < i) {
				kernel_sgemm_strsm_nt_rl_inv_4x4_vs_lib4(k, &pA[i*sda], &pB[j*sdb], j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &inv_diag_D[j], m-i, n-j);
			} else if (j < n-3) {
				kernel_ssyrk_spotrf_nt_l_4x4_lib4(k, &pA[i*sda], &pB[j*sdb], j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+j*sdc], &pD[j*bs+j*sdd], &inv_diag_D[j]);
			} else {
				kernel_ssyrk_spotrf_nt_l_4x4_vs_lib4(k, &pA[i*sda], &pB[j*sdb], j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+j*sdc], &pD[j*bs+j*sdd], &inv_diag_D[j], m-i, n-j);
			}
		}
	}

	/* trailing panel of 1-3 rows: variable-size kernels only */
	if (i < m) {
		for (j = 0; j < i && j < n-3; j += 4) {
			kernel_sgemm_strsm_nt_rl_inv_4x4_vs_lib4(k, &pA[i*sda], &pB[j*sdb], j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &inv_diag_D[j], m-i, n-j);
		}
		if (j < n) {
			if (j < i) {
				kernel_sgemm_strsm_nt_rl_inv_4x4_vs_lib4(k, &pA[i*sda], &pB[j*sdb], j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &inv_diag_D[j], m-i, n-j);
			} else {
				kernel_ssyrk_spotrf_nt_l_4x4_vs_lib4(k, &pA[i*sda], &pB[j*sdb], j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+j*sdc], &pD[j*bs+j*sdd], &inv_diag_D[j], m-i, n-j);
			}
		}
	}
}
113
114
115
/*
 * LU factorization without pivoting (old interface, panel-major storage with
 * panel height bs = 4).
 *
 * The matrix is processed one 4-row panel at a time. Within a panel:
 *  1. tiles left of the diagonal are obtained with a right-upper triangular
 *     solve against the already computed diagonal tiles (the "_inv" kernel);
 *  2. the diagonal tile is factorized;
 *  3. tiles right of the diagonal are obtained with a unit lower-triangular
 *     solve.
 * Full 4x4 kernels are used where a whole tile fits, "_vs" kernels otherwise.
 *
 * m, n        rows / columns; nothing is done if either is <= 0
 * pC, sdc     input matrix and its panel stride
 * pD, sdd     output factors and their panel stride
 * inv_diag_D  handed to the kernels at the offset of each column block;
 *             presumably receives the reciprocals of diag(U)
 */
void sgetrf_nn_nopivot_lib(int m, int n, float *pC, int sdc, float *pD, int sdd, float *inv_diag_D)
{
	const int bs = 4;
	int row, col, col_end;

	if (m <= 0 || n <= 0)
		return;

	/* full panels of 4 rows */
	for (row = 0; row < m-3; row += 4) {
		/*
		 * Columns left of the diagonal: col_end = min(n, row). It is a
		 * multiple of 4 unless n < row, in which case the last tile is
		 * narrower than 4 and needs the variable-size kernel.
		 */
		col_end = n < row ? n : row;
		col = 0;
		while (col < col_end-3) {
			kernel_strsm_nn_ru_inv_4x4_lib4(col, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &pD[col*bs+col*sdd], &inv_diag_D[col]);
			col += 4;
		}
		if (col < col_end) {
			kernel_strsm_nn_ru_inv_4x4_vs_lib4(col, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &pD[col*bs+col*sdd], &inv_diag_D[col], m-row, col_end-col);
			col += 4;
		}

		/* diagonal tile (skipped when the matrix has fewer columns than rows so far) */
		if (col < n-3) {
			kernel_sgetrf_nn_4x4_lib4(col, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &inv_diag_D[col]);
			col += 4;
		} else if (col < n) {
			kernel_sgetrf_nn_4x4_vs_lib4(col, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &inv_diag_D[col], m-row, n-col);
			col += 4;
		}

		/* columns right of the diagonal */
		while (col < n-3) {
			kernel_strsm_nn_ll_one_4x4_lib4(row, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &pD[row*bs+row*sdd]);
			col += 4;
		}
		if (col < n)
			kernel_strsm_nn_ll_one_4x4_vs_lib4(row, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &pD[row*bs+row*sdd], m-row, n-col);
	}

	if (row >= m)
		return;

	/* trailing panel of 1-3 rows: variable-size kernels throughout */
	col_end = n < row ? n : row;
	for (col = 0; col < col_end; col += 4)
		kernel_strsm_nn_ru_inv_4x4_vs_lib4(col, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &pD[col*bs+col*sdd], &inv_diag_D[col], m-row, col_end-col);
	if (col < n) {
		kernel_sgetrf_nn_4x4_vs_lib4(col, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &inv_diag_D[col], m-row, n-col);
		col += 4;
	}
	for ( ; col < n; col += 4)
		kernel_strsm_nn_ll_one_4x4_vs_lib4(row, &pD[row*sdd], &pD[col*bs], sdd, &pC[col*bs+row*sdc], &pD[col*bs+row*sdd], &pD[row*bs+row*sdd], m-row, n-col);
}
193
194
195
/*
 * Apply the row interchanges chosen by the 4-column panel pivot kernel.
 *
 * The pivot kernel was handed the sub-matrix starting at row i0 (and
 * &ipiv[i0]), so for each of the first npiv pivots this adds i0 to turn
 * ipiv[i0+kk] into an absolute row index. If that row differs from i0+kk,
 * it swaps the two rows in the columns left of the panel ([0, jj)) and right
 * of it ([jj+4, n)). The panel columns [jj, jj+4) are not touched here;
 * presumably the pivot kernel already permuted them. A non-positive column
 * count is passed to srowsw_lib unchanged when the panel is the last one.
 */
static void sgetrf_nn_lib_swap_rows(int npiv, int i0, int jj, int n, float *pD, int sdd, int *ipiv)
{
	const int bs = 4;
	int kk, r_cur, r_piv;

	for (kk = 0; kk < npiv; kk++) {
		r_cur = i0 + kk;
		ipiv[r_cur] += i0;
		r_piv = ipiv[r_cur];
		if (r_piv != r_cur) {
			srowsw_lib(jj, pD+r_cur/bs*bs*sdd+r_cur%bs, pD+r_piv/bs*bs*sdd+r_piv%bs);
			srowsw_lib(n-jj-4, pD+r_cur/bs*bs*sdd+r_cur%bs+(jj+4)*bs, pD+r_piv/bs*bs*sdd+r_piv%bs+(jj+4)*bs);
		}
	}
}

/*
 * LU factorization with partial (row) pivoting (old interface, panel-major
 * storage with panel height bs = 4).
 *
 * Works in place on pD only. pC and sdc are not read; the caller must already
 * have copied C into D, because rows of the not-yet-factorized columns are
 * swapped too. The factorization proceeds 4 columns at a time:
 *  1. update the panel with the previously factorized columns (gemm, -1/+1);
 *  2. pivot and factorize the panel;
 *  3. apply the row swaps to the rest of the matrix;
 *  4. compute the block row of U right of the panel (unit lower trsm).
 * The last 1-3 columns (m >= n) or rows (m < n) use variable-size kernels.
 *
 * m, n        rows / columns; nothing is done if m <= 0
 * pD, sdd     matrix to factorize in place and its panel stride
 * inv_diag_D  handed to the pivot kernels at each panel's column offset;
 *             presumably receives the reciprocals of diag(U)
 * ipiv        receives min(m, n) absolute row indices of the pivots
 */
void sgetrf_nn_lib(int m, int n, float *pC, int sdc, float *pD, int sdd, float *inv_diag_D, int *ipiv)
{
	const int bs = 4;
	float one = 1.0f;
	float minus_one = -1.0f;
	int row, col, top, nxt, pmin;

	(void)pC;
	(void)sdc;

	if (m <= 0)
		return;

	pmin = n < m ? n : m;

	/* full 4-column panels */
	for (col = 0; col < pmin-3; col += 4) {
		top = col;

		/* update the panel with the already factorized columns */
		for (row = top; row < m-3; row += 4)
			kernel_sgemm_nn_4x4_lib4(col, &minus_one, &pD[row*sdd], &pD[col*bs], sdd, &one, &pD[col*bs+row*sdd], &pD[col*bs+row*sdd]);
		if (m-row > 0)
			kernel_sgemm_nn_4x4_vs_lib4(col, &minus_one, &pD[row*sdd], &pD[col*bs], sdd, &one, &pD[col*bs+row*sdd], &pD[col*bs+row*sdd], m-row, 4);

		/* pivot + factorize the panel, then propagate the row swaps */
		kernel_sgetrf_pivot_4_lib4(m-top, &pD[col*bs+top*sdd], sdd, &inv_diag_D[col], &ipiv[top]);
		sgetrf_nn_lib_swap_rows(4, top, col, n, pD, sdd, ipiv);

		/* block row of U right of the panel */
		for (nxt = col+4; nxt < n-3; nxt += 4)
			kernel_strsm_nn_ll_one_4x4_lib4(top, &pD[top*sdd], &pD[nxt*bs], sdd, &pD[nxt*bs+top*sdd], &pD[nxt*bs+top*sdd], &pD[top*bs+top*sdd]);
		if (n-nxt > 0)
			kernel_strsm_nn_ll_one_4x4_vs_lib4(top, &pD[top*sdd], &pD[nxt*bs], sdd, &pD[nxt*bs+top*sdd], &pD[nxt*bs+top*sdd], &pD[top*bs+top*sdd], 4, n-nxt);
	}

	if (m >= n) {
		if (n-col > 0) {
			/* last 1-3 columns; nothing lies right of this panel */
			top = col;
			for (row = top; row < m; row += 4)
				kernel_sgemm_nn_4x4_vs_lib4(col, &minus_one, &pD[row*sdd], &pD[col*bs], sdd, &one, &pD[col*bs+row*sdd], &pD[col*bs+row*sdd], m-row, n-col);
			kernel_sgetrf_pivot_4_vs_lib4(m-top, n-col, &pD[col*bs+top*sdd], sdd, &inv_diag_D[col], &ipiv[top]);
			sgetrf_nn_lib_swap_rows(n-col < 4 ? n-col : 4, top, col, n, pD, sdd, ipiv);
		}
	} else if (m-col > 0) {
		/* last 1-3 rows; U still extends to the right of the panel */
		top = col;
		kernel_sgemm_nn_4x4_vs_lib4(col, &minus_one, &pD[top*sdd], &pD[col*bs], sdd, &one, &pD[col*bs+top*sdd], &pD[col*bs+top*sdd], m-top, n-col);
		kernel_sgetrf_pivot_4_vs_lib4(m-top, n-col, &pD[col*bs+top*sdd], sdd, &inv_diag_D[col], &ipiv[top]);
		sgetrf_nn_lib_swap_rows(m-top < 4 ? m-top : 4, top, col, n, pD, sdd, ipiv);
		for (nxt = col+4; nxt < n; nxt += 4)
			kernel_strsm_nn_ll_one_4x4_vs_lib4(top, &pD[top*sdd], &pD[nxt*bs], sdd, &pD[nxt*bs+top*sdd], &pD[nxt*bs+top*sdd], &pD[top*bs+top*sdd], m-top, n-nxt);
	}
}
395
396
397
398/****************************
399* new interface
400****************************/
401
402
403
404#if defined(LA_HIGH_PERFORMANCE)
405
406
407
408// dpotrf
409void spotrf_l_libstr(int m, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
410 {
411
412 if(m<=0)
413 return;
414
415 if(ci!=0 | di!=0)
416 {
417 printf("\nspotrf_l_libstr: feature not implemented yet: ci=%d, di=%d\n", ci, di);
418 exit(1);
419 }
420
421 const int bs = 4;
422
423 int sdc = sC->cn;
424 int sdd = sD->cn;
425 float *pC = sC->pA + cj*bs;
426 float *pD = sD->pA + dj*bs;
427 float *dD = sD->dA;
428 if(di==0 && dj==0) // XXX what to do if di and dj are not zero
429 sD->use_dA = 1;
430 else
431 sD->use_dA = 0;
432
433 int i, j, l;
434
435 i = 0;
436 for(; i<m-3; i+=4)
437 {
438 j = 0;
439 for(; j<i; j+=4)
440 {
441 kernel_strsm_nt_rl_inv_4x4_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &dD[j]);
442 }
443 kernel_spotrf_nt_l_4x4_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+j*sdc], &pD[j*bs+j*sdd], &dD[j]);
444 }
445 if(m>i)
446 {
447 goto left_4;
448 }
449
450 // common return if i==m
451 return;
452
453 // clean up loops definitions
454
455 left_4: // 1 - 3
456 j = 0;
457 for(; j<i; j+=4)
458 {
459 kernel_strsm_nt_rl_inv_4x4_vs_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &dD[j], m-i, m-j);
460 }
461 kernel_spotrf_nt_l_4x4_vs_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+j*sdc], &pD[j*bs+j*sdd], &dD[j], m-i, m-j);
462 return;
463
464 return;
465 }
466
467
468
469// dpotrf
470void spotrf_l_mn_libstr(int m, int n, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
471 {
472
473 if(m<=0 || n<=0)
474 return;
475
476 if(ci!=0 | di!=0)
477 {
478 printf("\nspotrf_l_libstr: feature not implemented yet: ci=%d, di=%d\n", ci, di);
479 exit(1);
480 }
481
482 const int bs = 4;
483
484 int sdc = sC->cn;
485 int sdd = sD->cn;
486 float *pC = sC->pA + cj*bs;
487 float *pD = sD->pA + dj*bs;
488 float *dD = sD->dA;
489 if(di==0 && dj==0) // XXX what to do if di and dj are not zero
490 sD->use_dA = 1;
491 else
492 sD->use_dA = 0;
493
494 int i, j, l;
495
496 i = 0;
497 for(; i<m-3; i+=4)
498 {
499 j = 0;
500 for(; j<i && j<n-3; j+=4)
501 {
502 kernel_strsm_nt_rl_inv_4x4_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &dD[j]);
503 }
504 if(j<n)
505 {
506 if(i<j) // dtrsm
507 {
508 kernel_strsm_nt_rl_inv_4x4_vs_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &dD[j], m-i, n-j);
509 }
510 else // dpotrf
511 {
512 if(j<n-3)
513 {
514 kernel_spotrf_nt_l_4x4_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+j*sdc], &pD[j*bs+j*sdd], &dD[j]);
515 }
516 else
517 {
518 kernel_spotrf_nt_l_4x4_vs_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+j*sdc], &pD[j*bs+j*sdd], &dD[j], m-i, n-j);
519 }
520 }
521 }
522 }
523 if(m>i)
524 {
525 goto left_4;
526 }
527
528 // common return if i==m
529 return;
530
531 // clean up loops definitions
532
533 left_4:
534 j = 0;
535 for(; j<i && j<n-3; j+=4)
536 {
537 kernel_strsm_nt_rl_inv_4x4_vs_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &dD[j], m-i, n-j);
538 }
539 if(j<n)
540 {
541 if(j<i) // dtrsm
542 {
543 kernel_strsm_nt_rl_inv_4x4_vs_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+i*sdc], &pD[j*bs+i*sdd], &pD[j*bs+j*sdd], &dD[j], m-i, n-j);
544 }
545 else // dpotrf
546 {
547 kernel_spotrf_nt_l_4x4_vs_lib4(j, &pD[i*sdd], &pD[j*sdd], &pC[j*bs+j*sdc], &pD[j*bs+j*sdd], &dD[j], m-i, n-j);
548 }
549 }
550 return;
551
552 return;
553 }
554
555
556
557// dsyrk dpotrf
558void ssyrk_spotrf_ln_libstr(int m, int n, int k, struct s_strmat *sA, int ai, int aj, struct s_strmat *sB, int bi, int bj, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
559 {
560 if(ai!=0 | bi!=0 | ci!=0 | di!=0)
561 {
562 printf("\nssyrk_spotrf_ln_libstr: feature not implemented yet: ai=%d, bi=%d, ci=%d, di=%d\n", ai, bi, ci, di);
563 exit(1);
564 }
565 const int bs = 4;
566 int sda = sA->cn;
567 int sdb = sB->cn;
568 int sdc = sC->cn;
569 int sdd = sD->cn;
570 float *pA = sA->pA + aj*bs;
571 float *pB = sB->pA + bj*bs;
572 float *pC = sC->pA + cj*bs;
573 float *pD = sD->pA + dj*bs;
574 float *dD = sD->dA; // XXX what to do if di and dj are not zero
575 ssyrk_spotrf_nt_l_lib(m, n, k, pA, sda, pB, sdb, pC, sdc, pD, sdd, dD);
576 if(di==0 && dj==0)
577 sD->use_dA = 1;
578 else
579 sD->use_dA = 0;
580 return;
581 }
582
583
584
585// dgetrf without pivoting
586void sgetrf_nopivot_libstr(int m, int n, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
587 {
588 if(ci!=0 | di!=0)
589 {
590 printf("\nsgetf_nopivot_libstr: feature not implemented yet: ci=%d, di=%d\n", ci, di);
591 exit(1);
592 }
593 const int bs = 4;
594 int sdc = sC->cn;
595 int sdd = sD->cn;
596 float *pC = sC->pA + cj*bs;
597 float *pD = sD->pA + dj*bs;
598 float *dD = sD->dA; // XXX what to do if di and dj are not zero
599 sgetrf_nn_nopivot_lib(m, n, pC, sdc, pD, sdd, dD);
600 if(di==0 && dj==0)
601 sD->use_dA = 1;
602 else
603 sD->use_dA = 0;
604 return;
605 }
606
607
608
609
610// dgetrf pivoting
611void sgetrf_libstr(int m, int n, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj, int *ipiv)
612 {
613 if(ci!=0 | di!=0)
614 {
615 printf("\nsgetrf_libstr: feature not implemented yet: ci=%d, di=%d\n", ci, di);
616 exit(1);
617 }
618 const int bs = 4;
619 int sdc = sC->cn;
620 int sdd = sD->cn;
621 float *pC = sC->pA + cj*bs;
622 float *pD = sD->pA + dj*bs;
623 float *dD = sD->dA; // XXX what to do if di and dj are not zero
624 // needs to perform row-excanges on the yet-to-be-factorized matrix too
625 if(pC!=pD)
626 sgecp_libstr(m, n, sC, ci, cj, sD, di, dj);
627 sgetrf_nn_lib(m, n, pC, sdc, pD, sdd, dD, ipiv);
628 if(di==0 && dj==0)
629 sD->use_dA = 1;
630 else
631 sD->use_dA = 0;
632 return;
633 }
634
635
636
/*
 * Work-space size query for sgeqrf_libstr.
 *
 * Not implemented in this build: it prints a message to stdout and
 * terminates the process with exit(1), so it never returns to the caller.
 */
int sgeqrf_work_size_libstr(int m, int n)
{
	(void)m;
	(void)n;

	printf("\nsgeqrf_work_size_libstr: feature not implemented yet\n");
	exit(1);

	return 0; /* unreachable: exit() does not return */
}
643
644
645
/*
 * QR factorization on strmat operands.
 *
 * Empty problems (m <= 0 or n <= 0) return immediately. Any other size is
 * not implemented: it prints a message to stdout and terminates the process
 * with exit(1).
 */
void sgeqrf_libstr(int m, int n, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj, void *work)
{
	(void)sC; (void)ci; (void)cj;
	(void)sD; (void)di; (void)dj;
	(void)work;

	if (m <= 0 || n <= 0)
		return;

	printf("\nsgeqrf_libstr: feature not implemented yet\n");
	exit(1);
}
654
655
656
657#else
658
659#error : wrong LA choice
660
661#endif
662
663
664