blob: 2731d1f72499dbe7c87e74198a54bddfbf135ced [file] [log] [blame]
Austin Schuh9a24b372018-01-28 16:12:29 -08001/**************************************************************************************************
2* *
3* This file is part of BLASFEO. *
4* *
5* BLASFEO -- BLAS For Embedded Optimization. *
6* Copyright (C) 2016-2017 by Gianluca Frison. *
7* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
8* All rights reserved. *
9* *
10* HPMPC is free software; you can redistribute it and/or *
11* modify it under the terms of the GNU Lesser General Public *
12* License as published by the Free Software Foundation; either *
13* version 2.1 of the License, or (at your option) any later version. *
14* *
15* HPMPC is distributed in the hope that it will be useful, *
16* but WITHOUT ANY WARRANTY; without even the implied warranty of *
17* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
18* See the GNU Lesser General Public License for more details. *
19* *
20* You should have received a copy of the GNU Lesser General Public *
21* License along with HPMPC; if not, write to the Free Software *
22* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA *
23* *
24* Author: Gianluca Frison, giaf (at) dtu.dk *
25* gianluca.frison (at) imtek.uni-freiburg.de *
26* *
27**************************************************************************************************/
28
29#include <stdlib.h>
30#include <stdio.h>
31
32#include "../include/blasfeo_common.h"
33#include "../include/blasfeo_d_kernel.h"
34
35
36
37/****************************
38* old interface
39****************************/
40
// D <- alpha * diag(dA) * B + beta * C
//
// dA holds the m diagonal entries of A as a plain vector. B, C and D are m x n
// matrices stored in panel-major format with panel height 4 and panel strides
// sdb, sdc, sdd respectively: the panel that starts at row ii (ii a multiple
// of 4) begins at offset ii*sd.
//
// Rows are processed one full panel (4 rows) per kernel call; the last 1-3
// rows, if any, are handled by a dedicated narrower kernel.
//
// NOTE(review): when beta==0 the full-panel loop uses the *_a0 kernel, which
// takes no C argument, but the 1-3 row tail kernels are still called with
// beta and pC. Confirm those kernels are safe when C is uninitialized
// (e.g. 0*NaN) for beta==0.
void dgemm_diag_left_lib(int m, int n, double alpha, double *dA, double *pB, int sdb, double beta, double *pC, int sdc, double *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	int ii = 0;

	if(beta==0.0)
		{
		// beta==0: use the variant that never touches C
		for( ; ii<m-3; ii+=4)
			{
			kernel_dgemm_diag_left_4_a0_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &pD[ii*sdd]);
			}
		}
	else
		{
		for( ; ii<m-3; ii+=4)
			{
			kernel_dgemm_diag_left_4_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
			}
		}

	// the loop above stops with 0 to 3 rows left over in the last panel
	int rem = m-ii;
	if(rem==1)
		kernel_dgemm_diag_left_1_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
	else if(rem==2)
		kernel_dgemm_diag_left_2_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
	else if(rem==3)
		kernel_dgemm_diag_left_3_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);

	}
77
78
79
// D <- alpha * A * diag(dB) + beta * C
//
// A, C and D are m x n matrices in panel-major format (panel height bs=4,
// panel strides sda, sdc, sdd); dB holds the n diagonal entries of B as a
// plain vector. Inside a panel, column jj starts at offset jj*bs, so the
// kernels are handed pointers advanced by jj*bs and the panel strides.
//
// Columns are consumed four at a time; 1-3 trailing columns go to a narrower
// kernel.
//
// NOTE(review): with beta==0 only the 4-column loop uses the *_a0 variant
// (which takes no C); the trailing-column kernels still receive beta and pC.
void dgemm_diag_right_lib(int m, int n, double alpha, double *pA, int sda, double *dB, double beta, double *pC, int sdc, double *pD, int sdd)
	{
	const int bs = 4;

	// nothing to do for an empty result
	if(!(m>0 && n>0))
		return;

	int jj = 0;

	if(beta!=0.0)
		{
		for( ; jj<n-3; jj+=4)
			kernel_dgemm_diag_right_4_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, &beta, pC+jj*bs, sdc, pD+jj*bs, sdd);
		}
	else
		{
		for( ; jj<n-3; jj+=4)
			kernel_dgemm_diag_right_4_a0_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, pD+jj*bs, sdd);
		}

	// n-jj is now in 0..3
	switch(n-jj)
		{
		case 1:
			kernel_dgemm_diag_right_1_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, &beta, pC+jj*bs, sdc, pD+jj*bs, sdd);
			break;
		case 2:
			kernel_dgemm_diag_right_2_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, &beta, pC+jj*bs, sdc, pD+jj*bs, sdd);
			break;
		case 3:
			kernel_dgemm_diag_right_3_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, &beta, pC+jj*bs, sdc, pD+jj*bs, sdd);
			break;
		default:
			// no trailing columns
			break;
		}
	}
116
117
118
119/****************************
120* new interface
121****************************/
122
123
124
125#if defined(LA_HIGH_PERFORMANCE)
126
127
128
129// dgemm with A diagonal matrix (stored as strvec)
130void dgemm_l_diag_libstr(int m, int n, double alpha, struct d_strvec *sA, int ai, struct d_strmat *sB, int bi, int bj, double beta, struct d_strmat *sC, int ci, int cj, struct d_strmat *sD, int di, int dj)
131 {
132 if(m<=0 | n<=0)
133 return;
134 if(bi!=0 | ci!=0 | di!=0)
135 {
136 printf("\ndgemm_l_diag_libstr: feature not implemented yet: bi=%d, ci=%d, di=%d\n", bi, ci, di);
137 exit(1);
138 }
139 const int bs = 4;
140 int sdb = sB->cn;
141 int sdc = sC->cn;
142 int sdd = sD->cn;
143 double *dA = sA->pa + ai;
144 double *pB = sB->pA + bj*bs;
145 double *pC = sC->pA + cj*bs;
146 double *pD = sD->pA + dj*bs;
147 dgemm_diag_left_lib(m, n, alpha, dA, pB, sdb, beta, pC, sdc, pD, sdd);
148 return;
149 }
150
151
152
153// dgemm with B diagonal matrix (stored as strvec)
154void dgemm_r_diag_libstr(int m, int n, double alpha, struct d_strmat *sA, int ai, int aj, struct d_strvec *sB, int bi, double beta, struct d_strmat *sC, int ci, int cj, struct d_strmat *sD, int di, int dj)
155 {
156 if(m<=0 | n<=0)
157 return;
158 if(ai!=0 | ci!=0 | di!=0)
159 {
160 printf("\ndgemm_r_diag_libstr: feature not implemented yet: ai=%d, ci=%d, di=%d\n", ai, ci, di);
161 exit(1);
162 }
163 const int bs = 4;
164 int sda = sA->cn;
165 int sdc = sC->cn;
166 int sdd = sD->cn;
167 double *pA = sA->pA + aj*bs;
168 double *dB = sB->pa + bi;
169 double *pC = sC->pA + cj*bs;
170 double *pD = sD->pA + dj*bs;
171 dgemm_diag_right_lib(m, n, alpha, pA, sda, dB, beta, pC, sdc, pD, sdd);
172 return;
173 }
174
175
176
177#else
178
179#error : wrong LA choice
180
181#endif
182
183
184