blob: e2986048ac634d925307d67a20bf01505140f39a [file] [log] [blame]
Austin Schuh9a24b372018-01-28 16:12:29 -08001/**************************************************************************************************
2* *
3* This file is part of BLASFEO. *
4* *
5* BLASFEO -- BLAS For Embedded Optimization. *
6* Copyright (C) 2016-2017 by Gianluca Frison. *
7* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
8* All rights reserved. *
9* *
10* HPMPC is free software; you can redistribute it and/or *
11* modify it under the terms of the GNU Lesser General Public *
12* License as published by the Free Software Foundation; either *
13* version 2.1 of the License, or (at your option) any later version. *
14* *
15* HPMPC is distributed in the hope that it will be useful, *
16* but WITHOUT ANY WARRANTY; without even the implied warranty of *
17* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
18* See the GNU Lesser General Public License for more details. *
19* *
20* You should have received a copy of the GNU Lesser General Public *
21* License along with HPMPC; if not, write to the Free Software *
22* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA *
23* *
24* Author: Gianluca Frison, giaf (at) dtu.dk *
25* gianluca.frison (at) imtek.uni-freiburg.de *
26* *
27**************************************************************************************************/
28
29#include <stdlib.h>
30#include <stdio.h>
31#include <sys/time.h>
32
33#include "../include/blasfeo_common.h"
34#include "../include/blasfeo_i_aux_ext_dep.h"
35#include "../include/blasfeo_v_aux_ext_dep.h"
36#include "../include/blasfeo_s_aux_ext_dep.h"
37#include "../include/blasfeo_s_aux.h"
38#include "../include/blasfeo_s_kernel.h"
39#include "../include/blasfeo_s_blas.h"
40
41
42int main()
43 {
44
45 printf("\nExample of LU factorization and backsolve\n\n");
46
47#if defined(LA_HIGH_PERFORMANCE)
48
49 printf("\nLA provided by BLASFEO\n\n");
50
51#elif defined(LA_REFERENCE)
52
53 printf("\nLA provided by REFERENCE\n\n");
54
55#elif defined(LA_BLAS)
56
57 printf("\nLA provided by BLAS\n\n");
58
59#else
60
61 printf("\nLA provided by ???\n\n");
62 exit(2);
63
64#endif
65
66 int ii;
67
68 int n = 16;
69
70 //
71 // matrices in column-major format
72 //
73
74 float *A; s_zeros(&A, n, n);
75 for(ii=0; ii<n*n; ii++) A[ii] = ii;
76// s_print_mat(n, n, A, n);
77
78 // spd matrix
79 float *B; s_zeros(&B, n, n);
80 for(ii=0; ii<n; ii++) B[ii*(n+1)] = 1.0;
81// s_print_mat(n, n, B, n);
82
83 // identity
84 float *I; s_zeros(&I, n, n);
85 for(ii=0; ii<n; ii++) I[ii*(n+1)] = 1.0;
86// s_print_mat(n, n, B, n);
87
88 // result matrix
89 float *D; s_zeros(&D, n, n);
90// s_print_mat(n, n, D, n);
91
92 // permutation indeces
93 int *ipiv; int_zeros(&ipiv, n, 1);
94
95 //
96 // matrices in matrix struct format
97 //
98
99 // work space enough for 5 matrix structs for size n times n
100 int size_strmat = 5*s_size_strmat(n, n);
101 void *memory_strmat; v_zeros_align(&memory_strmat, size_strmat);
102 char *ptr_memory_strmat = (char *) memory_strmat;
103
104 struct s_strmat sA;
105// s_allocate_strmat(n, n, &sA);
106 s_create_strmat(n, n, &sA, ptr_memory_strmat);
107 ptr_memory_strmat += sA.memory_size;
108 // convert from column major matrix to strmat
109 s_cvt_mat2strmat(n, n, A, n, &sA, 0, 0);
110 printf("\nA = \n");
111 s_print_strmat(n, n, &sA, 0, 0);
112
113 struct s_strmat sB;
114// s_allocate_strmat(n, n, &sB);
115 s_create_strmat(n, n, &sB, ptr_memory_strmat);
116 ptr_memory_strmat += sB.memory_size;
117 // convert from column major matrix to strmat
118 s_cvt_mat2strmat(n, n, B, n, &sB, 0, 0);
119 printf("\nB = \n");
120 s_print_strmat(n, n, &sB, 0, 0);
121
122 struct s_strmat sI;
123// s_allocate_strmat(n, n, &sI);
124 s_create_strmat(n, n, &sI, ptr_memory_strmat);
125 ptr_memory_strmat += sI.memory_size;
126 // convert from column major matrix to strmat
127
128 struct s_strmat sD;
129// s_allocate_strmat(n, n, &sD);
130 s_create_strmat(n, n, &sD, ptr_memory_strmat);
131 ptr_memory_strmat += sD.memory_size;
132
133 struct s_strmat sLU;
134// s_allocate_strmat(n, n, &sD);
135 s_create_strmat(n, n, &sLU, ptr_memory_strmat);
136 ptr_memory_strmat += sLU.memory_size;
137
138 sgemm_nt_libstr(n, n, n, 1.0, &sA, 0, 0, &sA, 0, 0, 1.0, &sB, 0, 0, &sD, 0, 0);
139 printf("\nB+A*A' = \n");
140 s_print_strmat(n, n, &sD, 0, 0);
141
142// sgetrf_nopivot_libstr(n, n, &sD, 0, 0, &sD, 0, 0);
143 sgetrf_libstr(n, n, &sD, 0, 0, &sLU, 0, 0, ipiv);
144 printf("\nLU = \n");
145 s_print_strmat(n, n, &sLU, 0, 0);
146 printf("\nipiv = \n");
147 int_print_mat(1, n, ipiv, 1);
148
149#if 0 // solve P L U X = P B
150 s_cvt_mat2strmat(n, n, I, n, &sI, 0, 0);
151 printf("\nI = \n");
152 s_print_strmat(n, n, &sI, 0, 0);
153
154 srowpe_libstr(n, ipiv, &sI);
155 printf("\nperm(I) = \n");
156 s_print_strmat(n, n, &sI, 0, 0);
157
158 strsm_llnu_libstr(n, n, 1.0, &sLU, 0, 0, &sI, 0, 0, &sD, 0, 0);
159 printf("\nperm(inv(L)) = \n");
160 s_print_strmat(n, n, &sD, 0, 0);
161 strsm_lunn_libstr(n, n, 1.0, &sLU, 0, 0, &sD, 0, 0, &sD, 0, 0);
162 printf("\ninv(A) = \n");
163 s_print_strmat(n, n, &sD, 0, 0);
164
165 // convert from strmat to column major matrix
166 s_cvt_strmat2mat(n, n, &sD, 0, 0, D, n);
167#else // solve X^T (P L U)^T = B^T P^T
168 s_cvt_tran_mat2strmat(n, n, I, n, &sI, 0, 0);
169 printf("\nI' = \n");
170 s_print_strmat(n, n, &sI, 0, 0);
171
172 scolpe_libstr(n, ipiv, &sB);
173 printf("\nperm(I') = \n");
174 s_print_strmat(n, n, &sB, 0, 0);
175
176 strsm_rltu_libstr(n, n, 1.0, &sLU, 0, 0, &sB, 0, 0, &sD, 0, 0);
177 printf("\nperm(inv(L')) = \n");
178 s_print_strmat(n, n, &sD, 0, 0);
179 strsm_rutn_libstr(n, n, 1.0, &sLU, 0, 0, &sD, 0, 0, &sD, 0, 0);
180 printf("\ninv(A') = \n");
181 s_print_strmat(n, n, &sD, 0, 0);
182
183 // convert from strmat to column major matrix
184 s_cvt_tran_strmat2mat(n, n, &sD, 0, 0, D, n);
185#endif
186
187 // print matrix in column-major format
188 printf("\ninv(A) = \n");
189 s_print_mat(n, n, D, n);
190
191
192
193 //
194 // free memory
195 //
196
197 s_free(A);
198 s_free(B);
199 s_free(D);
200 s_free(I);
201 int_free(ipiv);
202// s_free_strmat(&sA);
203// s_free_strmat(&sB);
204// s_free_strmat(&sD);
205// s_free_strmat(&sI);
206 v_free_align(memory_strmat);
207
208 return 0;
209
210 }
211