blob: 03192123e3bdb3ecd8004cb015367221599be49f [file] [log] [blame]
Austin Schuh9a24b372018-01-28 16:12:29 -08001/**************************************************************************************************
2* *
3* This file is part of BLASFEO. *
4* *
5* BLASFEO -- BLAS For Embedded Optimization. *
6* Copyright (C) 2016-2017 by Gianluca Frison. *
7* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
8* All rights reserved. *
9* *
10* HPMPC is free software; you can redistribute it and/or *
11* modify it under the terms of the GNU Lesser General Public *
12* License as published by the Free Software Foundation; either *
13* version 2.1 of the License, or (at your option) any later version. *
14* *
15* HPMPC is distributed in the hope that it will be useful, *
16* but WITHOUT ANY WARRANTY; without even the implied warranty of *
17* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
18* See the GNU Lesser General Public License for more details. *
19* *
20* You should have received a copy of the GNU Lesser General Public *
21* License along with HPMPC; if not, write to the Free Software *
22* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA *
23* *
24* Author: Gianluca Frison, giaf (at) dtu.dk *
25* gianluca.frison (at) imtek.uni-freiburg.de *
26* *
27**************************************************************************************************/
28
29#include <stdlib.h>
30#include <stdio.h>
31
32#include "../include/blasfeo_common.h"
33#include "../include/blasfeo_s_kernel.h"
34
35
36
37#if defined(LA_HIGH_PERFORMANCE)
38
39
40
41// dgemm with A diagonal matrix (stored as strvec)
42void sgemm_l_diag_libstr(int m, int n, float alpha, struct s_strvec *sA, int ai, struct s_strmat *sB, int bi, int bj, float beta, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
43 {
44
45 if(m<=0 | n<=0)
46 return;
47
48 if(bi!=0 | ci!=0 | di!=0)
49 {
50 printf("\nsgemm_l_diag_libstr: feature not implemented yet: bi=%d, ci=%d, di=%d\n", bi, ci, di);
51 exit(1);
52 }
53
54 const int bs = 4;
55
56 int sdb = sB->cn;
57 int sdc = sC->cn;
58 int sdd = sD->cn;
59 float *dA = sA->pa + ai;
60 float *pB = sB->pA + bj*bs;
61 float *pC = sC->pA + cj*bs;
62 float *pD = sD->pA + dj*bs;
63
64// sgemm_diag_left_lib(m, n, alpha, dA, pB, sdb, beta, pC, sdc, pD, sdd);
65 int ii;
66
67 ii = 0;
68 if(beta==0.0)
69 {
70 for( ; ii<m-3; ii+=4)
71 {
72 kernel_sgemm_diag_left_4_a0_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &pD[ii*sdd]);
73 }
74 }
75 else
76 {
77 for( ; ii<m-3; ii+=4)
78 {
79 kernel_sgemm_diag_left_4_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
80 }
81 }
82 if(m-ii>0)
83 {
84 if(m-ii==1)
85 kernel_sgemm_diag_left_1_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
86 else if(m-ii==2)
87 kernel_sgemm_diag_left_2_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
88 else // if(m-ii==3)
89 kernel_sgemm_diag_left_3_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
90 }
91
92 return;
93
94 }
95
96
97
98// dgemm with B diagonal matrix (stored as strvec)
99void sgemm_r_diag_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strvec *sB, int bi, float beta, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
100 {
101
102 if(m<=0 | n<=0)
103 return;
104
105 if(ai!=0 | ci!=0 | di!=0)
106 {
107 printf("\nsgemm_r_diag_libstr: feature not implemented yet: ai=%d, ci=%d, di=%d\n", ai, ci, di);
108 exit(1);
109 }
110
111 const int bs = 4;
112
113 int sda = sA->cn;
114 int sdc = sC->cn;
115 int sdd = sD->cn;
116 float *pA = sA->pA + aj*bs;
117 float *dB = sB->pa + bi;
118 float *pC = sC->pA + cj*bs;
119 float *pD = sD->pA + dj*bs;
120
121 int ii;
122
123 ii = 0;
124 if(beta==0.0)
125 {
126 for( ; ii<n-3; ii+=4)
127 {
128 kernel_sgemm_diag_right_4_a0_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &pD[ii*bs], sdd);
129 }
130 }
131 else
132 {
133 for( ; ii<n-3; ii+=4)
134 {
135 kernel_sgemm_diag_right_4_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &beta, &pC[ii*bs], sdc, &pD[ii*bs], sdd);
136 }
137 }
138 if(n-ii>0)
139 {
140 if(n-ii==1)
141 kernel_sgemm_diag_right_1_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &beta, &pC[ii*bs], sdc, &pD[ii*bs], sdd);
142 else if(n-ii==2)
143 kernel_sgemm_diag_right_2_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &beta, &pC[ii*bs], sdc, &pD[ii*bs], sdd);
144 else // if(n-ii==3)
145 kernel_sgemm_diag_right_3_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &beta, &pC[ii*bs], sdc, &pD[ii*bs], sdd);
146 }
147 return;
148
149 }
150
151
152
153#else
154
155#error : wrong LA choice
156
157#endif
158
159
160
161