blob: 8469345b4c62c3adf4da134088f779c9b471f772 [file] [log] [blame]
Austin Schuh9a24b372018-01-28 16:12:29 -08001/**************************************************************************************************
2* *
3* This file is part of BLASFEO. *
4* *
5* BLASFEO -- BLAS For Embedded Optimization. *
6* Copyright (C) 2016-2017 by Gianluca Frison. *
7* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
8* All rights reserved. *
9* *
10* HPMPC is free software; you can redistribute it and/or *
11* modify it under the terms of the GNU Lesser General Public *
12* License as published by the Free Software Foundation; either *
13* version 2.1 of the License, or (at your option) any later version. *
14* *
15* HPMPC is distributed in the hope that it will be useful, *
16* but WITHOUT ANY WARRANTY; without even the implied warranty of *
17* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
18* See the GNU Lesser General Public License for more details. *
19* *
20* You should have received a copy of the GNU Lesser General Public *
21* License along with HPMPC; if not, write to the Free Software *
22* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA *
23* *
24* Author: Gianluca Frison, giaf (at) dtu.dk *
25* gianluca.frison (at) imtek.uni-freiburg.de *
26* *
27**************************************************************************************************/
28
29#include <stdlib.h>
30#include <stdio.h>
31
32#include "../include/blasfeo_common.h"
33#include "../include/blasfeo_s_kernel.h"
34
35
36
37
38#if defined(LA_HIGH_PERFORMANCE)
39
40
41
42// dgemm with B diagonal matrix (stored as strvec)
43void sgemm_r_diag_libstr(int m, int n, float alpha, struct s_strmat *sA, int ai, int aj, struct s_strvec *sB, int bi, float beta, struct s_strmat *sC, int ci, int cj, struct s_strmat *sD, int di, int dj)
44 {
45
46 if(m<=0 | n<=0)
47 return;
48
49 if(ai!=0 | ci!=0 | di!=0)
50 {
51 printf("\nsgemm_r_diag_libstr: feature not implemented yet: ai=%d, ci=%d, di=%d\n", ai, ci, di);
52 exit(1);
53 }
54
55 const int bs = 8;
56
57 int sda = sA->cn;
58 int sdc = sC->cn;
59 int sdd = sD->cn;
60 float *pA = sA->pA + aj*bs;
61 float *dB = sB->pa + bi;
62 float *pC = sC->pA + cj*bs;
63 float *pD = sD->pA + dj*bs;
64
65 int ii;
66
67 ii = 0;
68 if(beta==0.0)
69 {
70 for( ; ii<n-3; ii+=4)
71 {
72 kernel_sgemm_diag_right_4_a0_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &pD[ii*bs], sdd);
73 }
74 }
75 else
76 {
77 for( ; ii<n-3; ii+=4)
78 {
79 kernel_sgemm_diag_right_4_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &beta, &pC[ii*bs], sdc, &pD[ii*bs], sdd);
80 }
81 }
82 if(n-ii>0)
83 {
84 if(n-ii==1)
85 kernel_sgemm_diag_right_1_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &beta, &pC[ii*bs], sdc, &pD[ii*bs], sdd);
86 else if(n-ii==2)
87 kernel_sgemm_diag_right_2_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &beta, &pC[ii*bs], sdc, &pD[ii*bs], sdd);
88 else // if(n-ii==3)
89 kernel_sgemm_diag_right_3_lib4(m, &alpha, &pA[ii*bs], sda, &dB[ii], &beta, &pC[ii*bs], sdc, &pD[ii*bs], sdd);
90 }
91 return;
92
93 }
94
95
96
97#else
98
99#error : wrong LA choice
100
101#endif
102
103
104
105