blob: 1d2378815138c7af9661808336c569e96f58b3ed [file] [log] [blame]
Austin Schuh9049e202022-02-20 17:34:16 -08001#include "polish.h"
2#include "lin_alg.h"
3#include "util.h"
4#include "auxil.h"
5#include "lin_sys.h"
6#include "kkt.h"
7#include "proj.h"
Austin Schuhd9e9dea2022-02-20 19:54:42 -08008#include "osqp_error.h"
Austin Schuh9049e202022-02-20 17:34:16 -08009
10/**
11 * Form reduced matrix A that contains only rows that are active at the
12 * solution.
13 * Ared = vstack[Alow, Aupp]
14 * Active constraints are guessed from the primal and dual solution returned by
15 * the ADMM.
16 * @param work Workspace
17 * @return Number of rows in Ared, negative if error
18 */
19static c_int form_Ared(OSQPWorkspace *work) {
20 c_int j, ptr;
21 c_int Ared_nnz = 0;
22
23 // Initialize counters for active constraints
24 work->pol->n_low = 0;
25 work->pol->n_upp = 0;
26
27 /* Guess which linear constraints are lower-active, upper-active and free
28 * A_to_Alow[j] = -1 (if j-th row of A is not inserted in Alow)
29 * A_to_Alow[j] = i (if j-th row of A is inserted at i-th row of Alow)
30 * Aupp is formed in the equivalent way.
31 * Ared is formed by stacking vertically Alow and Aupp.
32 */
33 for (j = 0; j < work->data->m; j++) {
34 if (work->z[j] - work->data->l[j] < -work->y[j]) { // lower-active
35 work->pol->Alow_to_A[work->pol->n_low] = j;
36 work->pol->A_to_Alow[j] = work->pol->n_low++;
37 } else {
38 work->pol->A_to_Alow[j] = -1;
39 }
40 }
41
42 for (j = 0; j < work->data->m; j++) {
43 if (work->data->u[j] - work->z[j] < work->y[j]) { // upper-active
44 work->pol->Aupp_to_A[work->pol->n_upp] = j;
45 work->pol->A_to_Aupp[j] = work->pol->n_upp++;
46 } else {
47 work->pol->A_to_Aupp[j] = -1;
48 }
49 }
50
51 // Check if there are no active constraints
52 if (work->pol->n_low + work->pol->n_upp == 0) {
53 // Form empty Ared
54 work->pol->Ared = csc_spalloc(0, work->data->n, 0, 1, 0);
55 if (!(work->pol->Ared)) return -1;
56 int_vec_set_scalar(work->pol->Ared->p, 0, work->data->n + 1);
57 return 0; // mred = 0
58 }
59
60 // Count number of elements in Ared
61 for (j = 0; j < work->data->A->p[work->data->A->n]; j++) {
62 if ((work->pol->A_to_Alow[work->data->A->i[j]] != -1) ||
63 (work->pol->A_to_Aupp[work->data->A->i[j]] != -1)) Ared_nnz++;
64 }
65
66 // Form Ared
67 // Ared = vstack[Alow, Aupp]
68 work->pol->Ared = csc_spalloc(work->pol->n_low + work->pol->n_upp,
69 work->data->n, Ared_nnz, 1, 0);
70 if (!(work->pol->Ared)) return -1;
71 Ared_nnz = 0; // counter
72
73 for (j = 0; j < work->data->n; j++) { // Cycle over columns of A
74 work->pol->Ared->p[j] = Ared_nnz;
75
76 for (ptr = work->data->A->p[j]; ptr < work->data->A->p[j + 1]; ptr++) {
77 // Cycle over elements in j-th column
78 if (work->pol->A_to_Alow[work->data->A->i[ptr]] != -1) {
79 // Lower-active rows of A
80 work->pol->Ared->i[Ared_nnz] =
81 work->pol->A_to_Alow[work->data->A->i[ptr]];
82 work->pol->Ared->x[Ared_nnz++] = work->data->A->x[ptr];
83 } else if (work->pol->A_to_Aupp[work->data->A->i[ptr]] != -1) {
84 // Upper-active rows of A
85 work->pol->Ared->i[Ared_nnz] = work->pol->A_to_Aupp[work->data->A->i[ptr]] \
86 + work->pol->n_low;
87 work->pol->Ared->x[Ared_nnz++] = work->data->A->x[ptr];
88 }
89 }
90 }
91
92 // Update the last element in Ared->p
93 work->pol->Ared->p[work->data->n] = Ared_nnz;
94
95 // Return number of rows in Ared
96 return work->pol->n_low + work->pol->n_upp;
97}
98
99/**
100 * Form reduced right-hand side rhs_red = vstack[-q, l_low, u_upp]
101 * @param work Workspace
102 * @param rhs right-hand-side
103 * @return reduced rhs
104 */
105static void form_rhs_red(OSQPWorkspace *work, c_float *rhs) {
106 c_int j;
107
108 // Form the rhs of the reduced KKT linear system
109 for (j = 0; j < work->data->n; j++) { // -q
110 rhs[j] = -work->data->q[j];
111 }
112
113 for (j = 0; j < work->pol->n_low; j++) { // l_low
114 rhs[work->data->n + j] = work->data->l[work->pol->Alow_to_A[j]];
115 }
116
117 for (j = 0; j < work->pol->n_upp; j++) { // u_upp
118 rhs[work->data->n + work->pol->n_low + j] =
119 work->data->u[work->pol->Aupp_to_A[j]];
120 }
121}
122
123/**
124 * Perform iterative refinement on the polished solution:
125 * (repeat)
126 * 1. (K + dK) * dz = b - K*z
127 * 2. z <- z + dz
128 * @param work Solver workspace
129 * @param p Private variable for solving linear system
130 * @param z Initial z value
131 * @param b RHS of the linear system
132 * @return Exitflag
133 */
134static c_int iterative_refinement(OSQPWorkspace *work,
135 LinSysSolver *p,
136 c_float *z,
137 c_float *b) {
138 c_int i, j, n;
139 c_float *rhs;
140
141 if (work->settings->polish_refine_iter > 0) {
142
143 // Assign dimension n
144 n = work->data->n + work->pol->Ared->m;
145
146 // Allocate rhs vector
147 rhs = (c_float *)c_malloc(sizeof(c_float) * n);
148
149 if (!rhs) {
150 return osqp_error(OSQP_MEM_ALLOC_ERROR);
151 } else {
152 for (i = 0; i < work->settings->polish_refine_iter; i++) {
153 // Form the RHS for the iterative refinement: b - K*z
154 prea_vec_copy(b, rhs, n);
155
156 // Upper Part: R^{n}
157 // -= Px (upper triang)
158 mat_vec(work->data->P, z, rhs, -1);
159
160 // -= Px (lower triang)
161 mat_tpose_vec(work->data->P, z, rhs, -1, 1);
162
163 // -= Ared'*y_red
164 mat_tpose_vec(work->pol->Ared, z + work->data->n, rhs, -1, 0);
165
166 // Lower Part: R^{m}
167 mat_vec(work->pol->Ared, z, rhs + work->data->n, -1);
168
169 // Solve linear system. Store solution in rhs
170 p->solve(p, rhs);
171
172 // Update solution
173 for (j = 0; j < n; j++) {
174 z[j] += rhs[j];
175 }
176 }
177 }
178 if (rhs) c_free(rhs);
179 }
180 return 0;
181}
182
183/**
184 * Compute dual variable y from yred
185 * @param work Workspace
186 * @param yred Dual variables associated to active constraints
187 */
188static void get_ypol_from_yred(OSQPWorkspace *work, c_float *yred) {
189 c_int j;
190
191 // If there are no active constraints
192 if (work->pol->n_low + work->pol->n_upp == 0) {
193 vec_set_scalar(work->pol->y, 0., work->data->m);
194 return;
195 }
196
197 // NB: yred = vstack[ylow, yupp]
198 for (j = 0; j < work->data->m; j++) {
199 if (work->pol->A_to_Alow[j] != -1) {
200 // lower-active
201 work->pol->y[j] = yred[work->pol->A_to_Alow[j]];
202 } else if (work->pol->A_to_Aupp[j] != -1) {
203 // upper-active
204 work->pol->y[j] = yred[work->pol->A_to_Aupp[j] + work->pol->n_low];
205 } else {
206 // inactive
207 work->pol->y[j] = 0.0;
208 }
209 }
210}
211
212c_int polish(OSQPWorkspace *work) {
213 c_int mred, polish_successful, exitflag;
214 c_float *rhs_red;
215 LinSysSolver *plsh;
216 c_float *pol_sol; // Polished solution
217
218#ifdef PROFILING
219 osqp_tic(work->timer); // Start timer
220#endif /* ifdef PROFILING */
221
222 // Form Ared by assuming the active constraints and store in work->pol->Ared
223 mred = form_Ared(work);
224 if (mred < 0) { // work->pol->red = OSQP_NULL
225 // Polishing failed
226 work->info->status_polish = -1;
227
228 return -1;
229 }
230
231 // Form and factorize reduced KKT
232 exitflag = init_linsys_solver(&plsh, work->data->P, work->pol->Ared,
233 work->settings->delta, OSQP_NULL,
234 work->settings->linsys_solver, 1);
235
236 if (exitflag) {
237 // Polishing failed
238 work->info->status_polish = -1;
239
240 // Memory clean-up
241 if (work->pol->Ared) csc_spfree(work->pol->Ared);
242
243 return 1;
244 }
245
246 // Form reduced right-hand side rhs_red
247 rhs_red = c_malloc(sizeof(c_float) * (work->data->n + mred));
248 if (!rhs_red) {
249 // Polishing failed
250 work->info->status_polish = -1;
251
252 // Memory clean-up
253 csc_spfree(work->pol->Ared);
254
255 return -1;
256 }
257 form_rhs_red(work, rhs_red);
258
259 pol_sol = vec_copy(rhs_red, work->data->n + mred);
260 if (!pol_sol) {
261 // Polishing failed
262 work->info->status_polish = -1;
263
264 // Memory clean-up
265 csc_spfree(work->pol->Ared);
266 c_free(rhs_red);
267
268 return -1;
269 }
270
271 // Solve the reduced KKT system
272 plsh->solve(plsh, pol_sol);
273
274 // Perform iterative refinement to compensate for the regularization error
275 exitflag = iterative_refinement(work, plsh, pol_sol, rhs_red);
276
277 if (exitflag) {
278 // Polishing failed
279 work->info->status_polish = -1;
280
281 // Memory clean-up
282 csc_spfree(work->pol->Ared);
283 c_free(rhs_red);
284 c_free(pol_sol);
285
286 return -1;
287 }
288
289 // Store the polished solution (x,z,y)
290 prea_vec_copy(pol_sol, work->pol->x, work->data->n); // pol->x
291 mat_vec(work->data->A, work->pol->x, work->pol->z, 0); // pol->z
292 get_ypol_from_yred(work, pol_sol + work->data->n); // pol->y
293
294 // Ensure (z,y) satisfies normal cone constraint
295 project_normalcone(work, work->pol->z, work->pol->y);
296
297 // Compute primal and dual residuals at the polished solution
298 update_info(work, 0, 1, 1);
299
300 // Check if polish was successful
301 polish_successful = (work->pol->pri_res < work->info->pri_res &&
302 work->pol->dua_res < work->info->dua_res) || // Residuals
303 // are
304 // reduced
305 (work->pol->pri_res < work->info->pri_res &&
306 work->info->dua_res < 1e-10) || // Dual
307 // residual
308 // already
309 // tiny
310 (work->pol->dua_res < work->info->dua_res &&
311 work->info->pri_res < 1e-10); // Primal
312 // residual
313 // already
314 // tiny
315
316 if (polish_successful) {
317 // Update solver information
318 work->info->obj_val = work->pol->obj_val;
319 work->info->pri_res = work->pol->pri_res;
320 work->info->dua_res = work->pol->dua_res;
321 work->info->status_polish = 1;
322
323 // Update (x, z, y) in ADMM iterations
324 // NB: z needed for warm starting
325 prea_vec_copy(work->pol->x, work->x, work->data->n);
326 prea_vec_copy(work->pol->z, work->z, work->data->m);
327 prea_vec_copy(work->pol->y, work->y, work->data->m);
328
329 // Print summary
330#ifdef PRINTING
331
332 if (work->settings->verbose) print_polish(work);
333#endif /* ifdef PRINTING */
334 } else { // Polishing failed
335 work->info->status_polish = -1;
336
337 // TODO: Try to find a better solution on the line connecting ADMM
338 // and polished solution
339 }
340
341 // Memory clean-up
342 plsh->free(plsh);
343
344 // Checks that they are not NULL are already performed earlier
345 csc_spfree(work->pol->Ared);
346 c_free(rhs_red);
347 c_free(pol_sol);
348
349 return 0;
350}