1 |
/* ndfun.c: MEX (MATLAB) file to implement functions that treat |
---|
2 |
multi-dimensional arrays as "pages" of 2D matrices. |
---|
3 |
|
---|
4 |
This allows you to do, for example, |
---|
5 |
C = ndfun('mult', A, B), |
---|
6 |
which is equivalent to |
---|
7 |
for i = 1:100 |
---|
8 |
C(:,:,i) = A(:,:,i) * B(:,:,i); |
---|
9 |
end |
---|
10 |
|
---|
11 |
except it is more flexible, since it does the same for any number |
---|
12 |
of dimensions. |
---|
13 |
|
---|
14 |
It also automatically reuses 2D matrices in either position, as in: |
---|
15 |
for i = 1:100 |
---|
16 |
C(:,:,i) = A * B(:,:,i); |
---|
17 |
end |
---|
18 |
|
---|
19 |
Supported operations are now multiplication, inverses, and square |
---|
20 |
matrix backslash. |
---|
21 |
|
---|
22 |
Debating including a "fast" option that skips the singularity |
---|
23 |
check. For 100x100 inverse, this would save 15%. Opinions? |
---|
24 |
|
---|
25 |
Author: Peter Boettcher <boettcher@ll.mit.edu> |
---|
26 |
Copyright 2002, Peter Boettcher |
---|
27 |
*/ |
---|
28 |
|
---|
29 |
|
---|
30 |
/* $Id: ndfun.c,v 1.6 2002/03/28 21:45:23 pwb Exp $ */ |
---|
31 |
|
---|
32 |
/* Revision identification string carrying the $Id$ keyword for this file.
   Nothing writes to it, so it is declared const.  Defining "lint" at
   compile time leaves it out altogether. */
#ifndef lint
static const char vcid[] = "$Id: ndfun.c,v 1.6 2002/03/28 21:45:23 pwb Exp $";
#endif /* lint */
---|
35 |
|
---|
36 |
#include "mex.h" |
---|
37 |
#include <string.h> |
---|
38 |
#include <math.h> |
---|
39 |
|
---|
40 |
/* Forward declarations of the helpers defined after mexFunction. */

/* Largest column sum of absolute values of the column-major m-by-n matrix A. */
double compute_norm(double *A, int m, int n);

/* In-place LU factorisation of the m-by-m matrix X; when check_singular is
   non-zero a MATLAB warning is issued for singular or near-singular input. */
void compute_lu(double *X, int m, int *ipivot, double *work, int *iwork,
                int check_singular);

/* Raises a MATLAB error naming blasfcn when info is negative. */
void blas_return_check(int info, const char *blasfcn);
---|
43 |
|
---|
44 |
/* Symbol naming for the Fortran BLAS/LAPACK routines.  The platforms named
   in the condition below use the bare routine name; everywhere else a
   trailing underscore is appended.  Add further exceptions to the list. */
#if !(defined(__OS2__) || defined(__WINDOWS__) || defined(WIN32))
#define BLASCALL(f) f ## _
#else
#define BLASCALL(f) f
#endif


/* BLAS/LAPACK function prototypes (added 03/28/02 Aj).  All arguments are
   passed by address. */

/* Matrix-matrix product; used by the 'mult' command. */
void BLASCALL(dgemm)(const char *TRANSA, const char *TRANSB,
                     const int *M, const int *N, const int *K,
                     const double *ALPHA,
                     const double *A, const int *LDA,
                     const double *B, const int *LDB,
                     const double *BETA,
                     double *C, const int *LDC);

/* Solve with an existing LU factorisation; used by the 'backslash' command. */
void BLASCALL(dgetrs)(const char *TRANS, const int *N, const int *NRHS,
                      const double *A, const int *LDA, const int *IPIV,
                      double *B, const int *LDB, int *INFO);

/* Inverse from an existing LU factorisation; used by the 'inv' command. */
void BLASCALL(dgetri)(const int *N, double *A, const int *LDA,
                      const int *IPIV, double *WORK, const int *LWORK,
                      int *INFO);

/* LU factorisation; called from compute_lu. */
void BLASCALL(dgetrf)(const int *M, const int *N, double *A, const int *LDA,
                      int *IPIV, int *INFO);

/* Reciprocal condition estimate (RCOND); called from compute_lu. */
void BLASCALL(dgecon)(const char *NORM, const int *N, const double *A,
                      const int *LDA, const double *ANORM, double *RCOND,
                      double *WORK, int *IWORK, int *INFO);
---|
68 |
|
---|
69 |
/* Typedefs */

/* Identifies the operation requested through the first MEX argument. */
typedef enum {
  COMMAND_INVALID = 0,
  COMMAND_MULT,
  COMMAND_INV,
  COMMAND_BACKSLASH,
  COMMAND_VERSION
} commandcode_t;

/* Bit flags selecting the shape checks page_dim_check performs. */
#define CHECK_SQUARE_A      1
#define CHECK_SQUARE_B      2
#define CHECK_AT_LEAST_2D_A 4
#define CHECK_AT_LEAST_2D_B 8

/* One entry per supported command: its name, its code, how many array
   operands follow the command string, and the CHECK_* mask to apply.
   The table ends with an entry whose cmdstr is NULL.  cmdstr is const
   because every entry points at a string literal. */
struct ndcommand_s {
  const char *cmdstr;
  commandcode_t commandcode;
  int num_args;
  int check;
} ndcommand_list[] = {
  {"mult",      COMMAND_MULT,      2, 0},
  {"inv",       COMMAND_INV,       1, CHECK_SQUARE_A | CHECK_AT_LEAST_2D_A},
  {"backslash", COMMAND_BACKSLASH, 2, CHECK_SQUARE_A | CHECK_AT_LEAST_2D_A},
  {"version",   COMMAND_VERSION,   0, 0},
  {NULL,        COMMAND_INVALID,   0, 0}
};
---|
88 |
|
---|
89 |
/* Threshold for the RCOND test in compute_lu.  mexFunction assigns it from
   mxGetEps() at the start of every call. */
double eps; |
---|
90 |
|
---|
91 |
struct ndcommand_s *get_command(const mxArray *mxCMD) |
---|
92 |
{ |
---|
93 |
char *commandstr; |
---|
94 |
int i=0; |
---|
95 |
|
---|
96 |
if(mxGetClassID(mxCMD) != mxCHAR_CLASS) |
---|
97 |
mexErrMsgTxt("First argument must be the command to use"); |
---|
98 |
commandstr = mxArrayToString(mxCMD); |
---|
99 |
|
---|
100 |
while(ndcommand_list[i].cmdstr) { |
---|
101 |
if (strcmp(commandstr, ndcommand_list[i].cmdstr) == 0) { |
---|
102 |
mxFree(commandstr); |
---|
103 |
return(&ndcommand_list[i]); |
---|
104 |
} |
---|
105 |
i++; |
---|
106 |
} |
---|
107 |
|
---|
108 |
mxFree(commandstr); |
---|
109 |
mexErrMsgTxt("Unknown command"); |
---|
110 |
|
---|
111 |
return(NULL); |
---|
112 |
} |
---|
113 |
|
---|
114 |
int page_dim_check(int numDimsA, int numDimsB, const int *dimsA, const int *dimsB, int dimcheck) |
---|
115 |
{ |
---|
116 |
int i, numPages=1; |
---|
117 |
|
---|
118 |
/* OK, valid possibilities are: |
---|
119 |
-Fully matching N-D arrays |
---|
120 |
-One 2D (or less) array and one arbitrary N-D array */ |
---|
121 |
|
---|
122 |
if((numDimsA <= 2) || (numDimsB <= 2)) { |
---|
123 |
/* repeated_arg = 1; */ |
---|
124 |
} else { |
---|
125 |
if(numDimsA != numDimsB) |
---|
126 |
mexErrMsgTxt("Invalid dimensions"); |
---|
127 |
for(i=2; i<numDimsA; i++) { |
---|
128 |
if(dimsA[i] != dimsB[i]) |
---|
129 |
mexErrMsgTxt("Dimensions after 2 must match"); |
---|
130 |
numPages *= dimsA[i]; |
---|
131 |
} |
---|
132 |
} |
---|
133 |
|
---|
134 |
if(dimcheck & CHECK_AT_LEAST_2D_A) { |
---|
135 |
if(numDimsA < 2) |
---|
136 |
mexErrMsgTxt("A must be at least 2D!"); |
---|
137 |
} |
---|
138 |
if(dimcheck & CHECK_AT_LEAST_2D_B) { |
---|
139 |
if(numDimsB < 2) |
---|
140 |
mexErrMsgTxt("B must be at least 2D!"); |
---|
141 |
} |
---|
142 |
|
---|
143 |
if(dimcheck & CHECK_SQUARE_A) { |
---|
144 |
if(dimsA[0] != dimsA[1]) |
---|
145 |
mexErrMsgTxt("A must be square in first 2 dimensions"); |
---|
146 |
} |
---|
147 |
if(dimcheck & CHECK_SQUARE_B) { |
---|
148 |
if(dimsB[0] != dimsB[1]) |
---|
149 |
mexErrMsgTxt("A must be square in first 2 dimensions"); |
---|
150 |
} |
---|
151 |
|
---|
152 |
return(0); |
---|
153 |
} |
---|
154 |
|
---|
155 |
|
---|
156 |
|
---|
157 |
/* Raise a MATLAB error unless arr is a full (non-sparse), real, double
   array.  The data pointer obtained with mxGetPr is handed directly to the
   double-precision BLAS/LAPACK routines, so nothing else is acceptable. */
static void require_real_double(const mxArray *arr)
{
  if(!mxIsDouble(arr) || mxIsComplex(arr) || mxIsSparse(arr))
    mexErrMsgTxt("Arguments must be full, real, double-precision arrays");
}

/* Number of doubles needed in the scratch buffer passed to compute_lu and
   DGETRI for an m-by-m matrix.  LAPACK documents DGECON (called from
   compute_lu) as needing 4*m doubles, which exceeds m*m when m < 4; the
   m*m size is kept for larger m. */
static size_t lu_work_len(int m)
{
  size_t order = (size_t)m;
  size_t square = order * order;
  size_t dgecon_min = 4 * order;

  return (square > dgecon_min) ? square : dgecon_min;
}

/* MEX entry point.

   Usage from MATLAB:
     C = ndfun('mult', A, B)       page-wise A*B
     C = ndfun('backslash', A, B)  page-wise A\B, A square
     C = ndfun('inv', A)           page-wise inverse, A square
     ndfun('version')              print version information

   A page is the 2D matrix formed by the first two dimensions.  For the
   two-operand commands an operand with fewer dimensions than the other is
   reused for every page of the other one.

   nlhs, plhs  output count and output array; plhs[0] receives the result
   nrhs, prhs  input count and inputs; prhs[0] is the command string

   Errors are raised through mexErrMsgTxt. */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  const int *dimsA=NULL, *dimsB=NULL, *dimsptr;
  int *dimsC;
  int numDimsA=0, numDimsB=0, numDimsC=0;
  int m=0, n=0, p=0, i;
  double *A, *B, *C, one = 1.0, zero = 0.0;
  int numPages=1;
  int strideA, strideB, strideC;
  struct ndcommand_s *command;
  const mxArray *mxA=NULL, *mxB=NULL;
  int *ipivot, info, *iwork;
  double *work, *scratchA;

  (void)nlhs;

  eps = mxGetEps();

  if(nrhs < 1)
    mexErrMsgTxt("Not enough arguments");

  /* Figure which command was chosen */
  command = get_command(prhs[0]);

  /* Set up some variables for the 2 and 1 argument cases */
  if(command->num_args == 2) {
    if(nrhs != 3)
      mexErrMsgTxt("Two arguments required");

    mxA = prhs[1];
    mxB = prhs[2];
    require_real_double(mxA);
    require_real_double(mxB);

    numDimsA = mxGetNumberOfDimensions(mxA);
    dimsA = mxGetDimensions(mxA);

    numDimsB = mxGetNumberOfDimensions(mxB);
    dimsB = mxGetDimensions(mxB);
  } else if(command->num_args == 1) {
    if(nrhs != 2)
      mexErrMsgTxt("One argument required");

    mxA = prhs[1];
    require_real_double(mxA);

    numDimsA = mxGetNumberOfDimensions(mxA);
    dimsA = mxGetDimensions(mxA);
  }

  /* Be sure dimensions agree in the necessary ways.  check is a
     bitmask of "necessary" checks to perform, which depends on the
     command chosen */
  page_dim_check(numDimsA, numDimsB, dimsA, dimsB, command->check);

  switch(command->commandcode) {
  case COMMAND_VERSION:
    mexPrintf("NDFUN MEX file\nCopyright 2002 Peter Boettcher\n%s\n",
              "$Revision: 1.6 $");
    break;

  case COMMAND_MULT:
    /******************************************
     * MULTIPLY
     ******************************************/

    if(dimsA[1] != dimsB[0])
      mexErrMsgTxt("Inner dimensions (first 2) don't match");

    m = dimsA[0];
    n = dimsB[1];
    p = dimsA[1];
    strideC = m*n;

    strideA = m*p;
    strideB = p*n;
    dimsptr = dimsA;
    numDimsC = numDimsA;

    /* A stride of zero makes the lower-dimensional operand repeat. */
    if(numDimsA != numDimsB) {
      if(numDimsA < numDimsB) {
        strideA = 0;
        numDimsC = numDimsB;
        dimsptr = dimsB;
      } else {
        strideB = 0;
      }
    }

    for(i=2; i<numDimsC; i++)
      numPages *= dimsptr[i];

    dimsC = (int *)mxMalloc((size_t)numDimsC*sizeof(int));
    dimsC[0] = m;
    dimsC[1] = n;
    for(i=2; i<numDimsC; i++)
      dimsC[i] = dimsptr[i];

    plhs[0] = mxCreateNumericArray(numDimsC, dimsC, mxDOUBLE_CLASS, mxREAL);
    C = mxGetPr(plhs[0]);
    A = mxGetPr(mxA);
    B = mxGetPr(mxB);

    /* With an empty operand the zero-filled result is already correct, and
       DGEMM would be handed a leading dimension of 0. */
    if(m > 0 && n > 0 && p > 0) {
      for(i=0; i<numPages; i++) {
        BLASCALL(dgemm)("N", "N", &m, &n, &p, &one,
                        A + (size_t)i*(size_t)strideA, &m,
                        B + (size_t)i*(size_t)strideB, &p,
                        &zero, C + (size_t)i*(size_t)strideC, &m);
      }
    }

    mxFree(dimsC);
    break;

  case COMMAND_BACKSLASH:
    /******************************************
     * BACKSLASH
     ******************************************/

    if(dimsA[0] != dimsB[0])
      mexErrMsgTxt("First dimensions must match");

    m = dimsA[0];
    n = dimsA[1];
    p = dimsB[1];

    strideC = n*p;
    strideA = m*n;
    strideB = m*p;
    dimsptr = dimsA;
    numDimsC = numDimsA;

    if(numDimsA != numDimsB) {
      if(numDimsA < numDimsB) {
        strideA = 0;
        numDimsC = numDimsB;
        dimsptr = dimsB;
      } else {
        strideB = 0;
      }
    }
    for(i=2; i<numDimsC; i++)
      numPages *= dimsptr[i];
    dimsC = (int *)mxMalloc((size_t)numDimsC*sizeof(int));
    dimsC[0] = n;
    dimsC[1] = p;
    for(i=2; i<numDimsC; i++)
      dimsC[i] = dimsptr[i];

    plhs[0] = mxCreateNumericArray(numDimsC, dimsC, mxDOUBLE_CLASS, mxREAL);
    mxFree(dimsC);

    /* Nothing to solve for a 0-by-0 system; the empty result stands. */
    if(m == 0)
      break;

    C = mxGetPr(plhs[0]);
    A = mxGetPr(mxA);
    B = mxGetPr(mxB);

    ipivot = (int *)mxMalloc((size_t)m*sizeof(int));
    iwork = (int *)mxMalloc((size_t)m*sizeof(int));
    work = (double *)mxMalloc(lu_work_len(m)*sizeof(double));
    scratchA = (double *)mxMalloc((size_t)m*(size_t)n*sizeof(double));

    if(numDimsA < numDimsB) {
      /* Single A, multiple B.  That means do one LU on A, and multiple
         solves.  Only an m*n temporary is needed this way. */
      memcpy(scratchA, A, mxGetNumberOfElements(mxA)*sizeof(double));
      memcpy(C, B, (size_t)m*(size_t)p*(size_t)numPages*sizeof(double));
      compute_lu(scratchA, m, ipivot, work, iwork, 1);

      /* Loop over pages of B and compute */
      for(i=0; i<numPages; i++) {
        BLASCALL(dgetrs)("N", &m, &p, scratchA, &m, ipivot,
                         C + (size_t)i*(size_t)strideC, &m, &info);
        blas_return_check(info, "DGETRS");
      }
    } else {
      /* Multiple A.  Do the LU each step through */
      for(i=0; i<numPages; i++) {
        memcpy(scratchA, A + (size_t)i*(size_t)strideA,
               (size_t)m*(size_t)n*sizeof(double));
        compute_lu(scratchA, m, ipivot, work, iwork, 1);

        /* DGETRS solves in place, so seed the output page with B first. */
        memcpy(C + (size_t)i*(size_t)strideC, B + (size_t)i*(size_t)strideB,
               (size_t)m*(size_t)p*sizeof(double));
        BLASCALL(dgetrs)("N", &m, &p, scratchA, &m, ipivot,
                         C + (size_t)i*(size_t)strideC, &m, &info);
        blas_return_check(info, "DGETRS");
      }
    }

    mxFree(iwork);
    mxFree(scratchA);
    mxFree(ipivot);
    mxFree(work);
    break;

  case COMMAND_INV:
    /******************************************
     * INVERSE
     ******************************************/
    m = dimsA[0];
    n = dimsA[1];

    /* The copy of A is factorised and inverted in place. */
    plhs[0] = mxDuplicateArray(mxA);

    /* An empty input has nothing to invert; the empty copy stands. */
    if(m == 0)
      break;

    ipivot = (int *)mxMalloc((size_t)m*sizeof(int));
    work = (double *)mxMalloc(lu_work_len(m)*sizeof(double));
    iwork = (int *)mxMalloc((size_t)m*sizeof(int));

    C = mxGetPr(plhs[0]);
    strideC = n*m;

    for(i=2; i<numDimsA; i++)
      numPages *= dimsA[i];

    for(i=0; i<numPages; i++) {
      compute_lu(C + (size_t)i*(size_t)strideC, m, ipivot, work, iwork, 1);

      BLASCALL(dgetri)(&n, C + (size_t)i*(size_t)strideC, &m, ipivot, work, &n, &info);
      blas_return_check(info, "DGETRI");
    }

    mxFree(ipivot);
    mxFree(work);
    mxFree(iwork);
    break;

  default:
    mexErrMsgTxt("Should never get here");
  }

}
---|
374 |
|
---|
375 |
/* Wrapper function for LU decomposition. Optionally checks |
---|
376 |
singularity of result. For efficiency, pass in the scratch |
---|
377 |
buffers. Result appears in-place. See BLAS docs on DGETRF and |
---|
378 |
DGECON for required scratch buffer sizes. */ |
---|
379 |
void compute_lu(double *X, int m, int *ipivot, double *work, int *iwork, int check_singular) |
---|
380 |
{ |
---|
381 |
double anorm, rcond; |
---|
382 |
int info; |
---|
383 |
char errmsg[255]; |
---|
384 |
|
---|
385 |
anorm = compute_norm(X, m, m); |
---|
386 |
BLASCALL(dgetrf)(&m, &m, X, &m, ipivot, &info); /* LU call */ |
---|
387 |
blas_return_check(info, "DGETRF"); |
---|
388 |
|
---|
389 |
if(check_singular) { |
---|
390 |
/* Check singularity */ |
---|
391 |
if(info>0) |
---|
392 |
mexWarnMsgTxt("Matrix is singular to working precision"); |
---|
393 |
else { |
---|
394 |
BLASCALL(dgecon)("1", &m, X, &m, &anorm, &rcond, work, iwork, &info); |
---|
395 |
blas_return_check(info, "DGECON"); |
---|
396 |
|
---|
397 |
if(rcond < eps) { |
---|
398 |
sprintf(errmsg, "%s\n %s RCOND = %e.", |
---|
399 |
"Matrix is close to singular or badly scaled.", |
---|
400 |
"Results may be inaccurate.", rcond); |
---|
401 |
mexWarnMsgTxt(errmsg); |
---|
402 |
} |
---|
403 |
} |
---|
404 |
} |
---|
405 |
|
---|
406 |
} |
---|
407 |
|
---|
408 |
/* Check the INFO value returned by a BLAS/LAPACK call.

   info     INFO as returned by the routine
   blasfcn  routine name, used only in the error text

   A negative info raises a MATLAB error through mexErrMsgTxt naming the
   routine and the offending argument position (-info).  Zero and positive
   values are left to the caller. */
void blas_return_check(int info, const char *blasfcn)
{
  char errmsg[255];

  if(info >= 0)
    return;

  /* The precision on %s caps the routine name so the message always fits
     in errmsg.  info is known to be negative here, so -info replaces
     abs(), which needs a header this file does not include itself. */
  sprintf(errmsg, "Internal error: Illegal %.64s call, problem in arg %i",
          blasfcn, -info);
  mexErrMsgTxt(errmsg);
}
---|
419 |
|
---|
420 |
/* Return the largest column sum of absolute values of A.

   A     matrix stored column by column (element (r,c) at A[m*c + r])
   m, n  number of rows and columns

   Returns 0.0 when the matrix has no elements. */
double compute_norm(double *A, int m, int n)
{
  double best = 0.0;
  int c, r;

  for(c=0; c<n; c++) {
    double colsum = 0;

    for(r=0; r<m; r++)
      colsum += fabs(A[m*c + r]);

    best = (colsum > best) ? colsum : best;
  }

  return(best);
}
---|