NCCOOS Trac Projects: Top | Web | Platforms | Processing | Viz | Sprints | Sandbox | (Wind)

root/gliderproc/trunk/MATLAB/matutil/ndfun.c

Revision 495 (checked in by cbc, 12 years ago)

Initial import of Stark code.

Line 
1 /* ndfun.c: MEX (MATLAB) file to implement functions that treat
2    multi-dimensional arrays as "pages" of 2D matrices.
3    
4    This allows you to do, for example,
5      C = ndfun('mult', A, B),
6    which is equivalent to
7      for i = 1:100
8        C(:,:,i) = A(:,:,i) * B(:,:,i);
9      end
10
11    except it is more flexible, since it does the same for any number
12    of dimensions.
13
14    It also automatically reuses 2D matrices in either position, as in:
15      for i = 1:100
16        C(:,:,i) = A * B(:,:,i);
17      end
18
19    Supported operations are now multiplication, inverses, and square
20    matrix backslash.
21
22    Debating including a "fast" option that skips the singularity
23    check.  For 100x100 inverse, this would save 15%.  Opinions?
24
25    Author: Peter Boettcher <boettcher@ll.mit.edu>
26    Copyright 2002, Peter Boettcher
27 */
28
29
/*      $Id: ndfun.c,v 1.6 2002/03/28 21:45:23 pwb Exp $         */

/* RCS identification string for this source revision; left out when the
   file is processed with `lint` defined. */
#if !defined(lint)
static char vcid[] = "$Id: ndfun.c,v 1.6 2002/03/28 21:45:23 pwb Exp $";
#endif
35
36 #include "mex.h"
37 #include <string.h>
38 #include <math.h>
39
/* File-local helpers, defined at the end of this file. */
void blas_return_check(int info, const char *blasfcn);
void compute_lu(double *X, int m, int *ipivot, double *work, int *iwork, int check_singular);
double compute_norm(double *A, int m, int n);

/* Fortran symbol naming.  On the platforms tested below the BLAS/LAPACK
   routines are referenced by their plain names; everywhere else a
   trailing underscore is appended.  Add further platforms that do not
   mangle names to this condition. */
#if !defined(__OS2__) && !defined(__WINDOWS__) && !defined(WIN32)
#define BLASCALL(f) f ## _
#else
#define BLASCALL(f) f
#endif

/* BLAS/LAPACK function prototypes (added 03/28/02 Aj).  Every argument is
   passed by address, Fortran style.
   NOTE(review): some Fortran compilers expect hidden string-length
   arguments for CHARACTER parameters (TRANSA, TRANS, NORM, ...); these
   prototypes do not supply them.  Confirm against the BLAS/LAPACK library
   actually linked (MATLAB's own blas.h / lapack.h may be a safer source). */

/* General matrix-matrix product: C = ALPHA*op(A)*op(B) + BETA*C. */
void BLASCALL(dgemm)(const char *TRANSA, const char *TRANSB,
                     const int *M, const int *N, const int *K, const double *ALPHA,
                     const double A[], const int *LDA, const double *B, const int *LDB,
                     const double *BETA, double C[], const int *LDC);

/* LU factorisation with partial pivoting, computed in place. */
void BLASCALL(dgetrf)(const int *M, const int *N, double A[], const int *LDA,
                      int IPIV[], int *INFO);

/* Solve a linear system using the LU factors produced by DGETRF. */
void BLASCALL(dgetrs)(const char *TRANS, const int *N, const int *NRHS,
                      const double A[], const int *LDA, const int IPIV[], double B[],
                      const int *LDB, int *INFO);

/* Matrix inverse from the LU factors produced by DGETRF, in place. */
void BLASCALL(dgetri)(const int *N, double A[], const int *LDA, const int IPIV[],
                      double WORK[], const int LWORK[], int *INFO);

/* Reciprocal condition number estimate from the LU factors. */
void BLASCALL(dgecon)(const char *NORM, const int *N, const double A[],
                      const int *LDA, const double *ANORM, double *RCOND, double WORK[],
                      int IWORK[], int *INFO);
68
/* Typedefs */
typedef enum {COMMAND_INVALID=0, COMMAND_MULT, COMMAND_INV,
              COMMAND_BACKSLASH, COMMAND_VERSION} commandcode_t;

/* Bit flags for ndcommand_s.check: the shape checks page_dim_check
   applies to a command's operands. */
#define CHECK_SQUARE_A 1
#define CHECK_SQUARE_B 2
#define CHECK_AT_LEAST_2D_A 4
#define CHECK_AT_LEAST_2D_B 8

/* Table of supported commands, terminated by an entry whose cmdstr is
   NULL.  cmdstr points at a string literal and is therefore const. */
struct ndcommand_s {
  const char *cmdstr;          /* command name as passed from MATLAB */
  commandcode_t commandcode;
  int num_args;                /* number of array arguments after the command */
  int check;                   /* CHECK_* bitmask for page_dim_check */
} ndcommand_list[] = {{"mult", COMMAND_MULT, 2, 0},
                      {"inv", COMMAND_INV, 1, CHECK_SQUARE_A | CHECK_AT_LEAST_2D_A},
                      {"backslash", COMMAND_BACKSLASH, 2, CHECK_SQUARE_A | CHECK_AT_LEAST_2D_A},
                      {"version", COMMAND_VERSION, 0, 0},
                      {NULL, COMMAND_INVALID, 0, 0}};
88
/* Machine epsilon.  mexFunction refreshes it from mxGetEps() on every
   call, and compute_lu warns when the estimated RCOND falls below it.
   Internal linkage keeps this generic name out of the MEX object's
   exported symbols. */
static double eps;
90
91 struct ndcommand_s *get_command(const mxArray *mxCMD)
92 {
93   char *commandstr;
94   int i=0;
95
96   if(mxGetClassID(mxCMD) != mxCHAR_CLASS)
97     mexErrMsgTxt("First argument must be the command to use");
98   commandstr = mxArrayToString(mxCMD);
99  
100   while(ndcommand_list[i].cmdstr) {
101     if (strcmp(commandstr, ndcommand_list[i].cmdstr) == 0) {
102       mxFree(commandstr);
103       return(&ndcommand_list[i]);
104     }
105     i++;
106   }
107
108   mxFree(commandstr);
109   mexErrMsgTxt("Unknown command");
110
111   return(NULL);
112 }
113
114 int page_dim_check(int numDimsA, int numDimsB, const int *dimsA, const int *dimsB, int dimcheck)
115 {
116   int i, numPages=1;
117  
118   /* OK, valid possibilities are:
119      -Fully matching N-D arrays
120      -One 2D (or less) array and one arbitrary N-D array */
121
122   if((numDimsA <= 2) || (numDimsB <= 2)) {
123     /* repeated_arg = 1; */
124   } else {
125     if(numDimsA != numDimsB)
126       mexErrMsgTxt("Invalid dimensions");
127     for(i=2; i<numDimsA; i++) {
128       if(dimsA[i] != dimsB[i])
129         mexErrMsgTxt("Dimensions after 2 must match");
130       numPages *= dimsA[i];
131     }
132   }     
133
134   if(dimcheck & CHECK_AT_LEAST_2D_A) {
135     if(numDimsA < 2)
136       mexErrMsgTxt("A must be at least 2D!");
137   }
138   if(dimcheck & CHECK_AT_LEAST_2D_B) {
139     if(numDimsB < 2)
140       mexErrMsgTxt("B must be at least 2D!");
141   }
142      
143   if(dimcheck & CHECK_SQUARE_A) {
144     if(dimsA[0] != dimsA[1])
145       mexErrMsgTxt("A must be square in first 2 dimensions");
146   }
147   if(dimcheck & CHECK_SQUARE_B) {
148     if(dimsB[0] != dimsB[1])
149       mexErrMsgTxt("A must be square in first 2 dimensions");
150   }
151  
152   return(0);
153 }
154    
155
156
157 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
158 {
159   const int *dimsA=NULL, *dimsB=NULL, *dimsptr;
160   int *dimsC;
161   int numDimsA=0, numDimsB=0, numDimsC=0;
162   int m=0, n=0, p=0, i;
163   double *A, *B, *C, one = 1.0, zero = 0.0;
164   int numPages=1;
165   int strideA, strideB, strideC;
166   struct ndcommand_s *command;
167   const mxArray *mxA=NULL, *mxB=NULL;
168   int *ipivot, info, *iwork;
169   double *work, *scratchA;
170  
171   eps = mxGetEps();
172
173   if(nrhs < 1)
174     mexErrMsgTxt("Not enough arguments");
175
176   /* Figure which command was chosen */
177   command = get_command(prhs[0]);
178
179   /* Set up some variables for the 2 and 1 argument cases */
180   if(command->num_args == 2) {
181     if(nrhs != 3)
182       mexErrMsgTxt("Two arguments required");
183
184     mxA = prhs[1];
185     mxB = prhs[2];
186    
187     numDimsA = mxGetNumberOfDimensions(mxA);
188     dimsA = mxGetDimensions(mxA);
189
190     numDimsB = mxGetNumberOfDimensions(mxB);
191     dimsB = mxGetDimensions(mxB);
192   } else if(command->num_args == 1) {
193     if(nrhs != 2)
194       mexErrMsgTxt("One argument required");
195
196     mxA = prhs[1];
197     numDimsA = mxGetNumberOfDimensions(mxA);
198     dimsA = mxGetDimensions(mxA);
199   }
200
201   /* Be sure dimensions agree in the necessary ways.  check is a
202      bitmask of "necessary" checks to perform, which depends on the
203      command chosen */
204   page_dim_check(numDimsA, numDimsB, dimsA, dimsB, command->check);
205
206   switch(command->commandcode) {
207   case COMMAND_VERSION:
208     mexPrintf("NDFUN MEX file\nCopyright 2002 Peter Boettcher\n%s\n",
209               "$Revision: 1.6 $");
210     break;
211   case COMMAND_MULT:
212     /******************************************
213      * MULTIPLY
214      ******************************************/
215
216     if(dimsA[1] != dimsB[0])
217       mexErrMsgTxt("Inner dimensions (first 2) don't match");
218    
219     m = dimsA[0];
220     n = dimsB[1];
221     p = dimsA[1];
222     strideC = m*n;
223
224     strideA = m*p;
225     strideB = p*n;
226     dimsptr = dimsA;
227     numDimsC = numDimsA;
228
229     if(numDimsA != numDimsB) {
230       if(numDimsA < numDimsB) {
231         strideA = 0;
232         numDimsC = numDimsB;
233         dimsptr = dimsB;
234       } else {
235         strideB = 0;
236       }
237     }
238
239     for(i=2; i<numDimsC; i++)
240       numPages *= dimsptr[i];
241
242     dimsC = (int *)mxMalloc(numDimsC*sizeof(int));
243     dimsC[0] = m;
244     dimsC[1] = n;
245     for(i=2; i<numDimsC; i++)
246       dimsC[i] = dimsptr[i];
247    
248     plhs[0] = mxCreateNumericArray(numDimsC, dimsC, mxDOUBLE_CLASS, mxREAL);
249     C = mxGetPr(plhs[0]);
250     A = mxGetPr(mxA);
251     B = mxGetPr(mxB);
252    
253     for(i=0; i<numPages; i++) {
254       BLASCALL(dgemm)("N", "N", &m, &n, &p, &one, A + i*strideA, &m, B + i*strideB,
255                       &p, &zero, C + i*strideC, &m);
256     }
257  
258     mxFree(dimsC);
259     break;
260
261   case COMMAND_BACKSLASH:
262     /******************************************
263      * BACKSLASH
264      ******************************************/
265
266     if(dimsA[0] != dimsB[0])
267       mexErrMsgTxt("First dimensions must match");
268    
269     m = dimsA[0];
270     n = dimsA[1];
271     p = dimsB[1];
272    
273     strideC = n*p;
274     strideA = m*n;
275     strideB = m*p;
276     dimsptr = dimsA;
277     numDimsC = numDimsA;
278
279     if(numDimsA != numDimsB) {
280       if(numDimsA < numDimsB) {
281         strideA = 0;
282         numDimsC = numDimsB;
283         dimsptr = dimsB;
284       } else {
285         strideB = 0;
286       }
287     }
288     for(i=2; i<numDimsC; i++)
289       numPages *= dimsptr[i];
290     dimsC = (int *)mxMalloc(numDimsC*sizeof(int));
291     dimsC[0] = n;
292     dimsC[1] = p;
293     for(i=2; i<numDimsC; i++)
294       dimsC[i] = dimsptr[i];
295    
296     plhs[0] = mxCreateNumericArray(numDimsC, dimsC, mxDOUBLE_CLASS, mxREAL);
297
298     C = mxGetPr(plhs[0]);
299     A = mxGetPr(mxA);
300     B = mxGetPr(mxB);
301  
302     ipivot = (int *)mxMalloc(m*sizeof(int));
303     iwork = (int *)mxMalloc(m*sizeof(int));
304     work = (double *)mxMalloc(m*m*sizeof(double));
305     scratchA = (double *)mxMalloc(m*n*sizeof(double));
306
307
308     if(numDimsA < numDimsB) {
309       /* Single A, multiple B.  That means do one LU on A, and multiple solves */
310       /* Save memory by doing it this way... that way we need only a m*n temp array */
311       memcpy(scratchA, A, mxGetNumberOfElements(mxA)*sizeof(double));
312       memcpy(C, B, m*p*numPages*sizeof(double));
313       compute_lu(scratchA, m, ipivot, work, iwork, 1);
314
315       /* Loop over pages of B and compute */
316       for(i=0; i<numPages; i++) {
317         BLASCALL(dgetrs)("N", &m, &p, scratchA, &m, ipivot, C + i*strideC, &m, &info);
318         blas_return_check(info, "DGETRS");
319       }
320     } else {
321       /* Multiple A.  Do the LU each step through */
322       for(i=0; i<numPages; i++) {
323         memcpy(scratchA, A + i*strideA, m*n*sizeof(double));
324         compute_lu(scratchA, m, ipivot, work, iwork, 1);
325
326         /* Compute */
327         memcpy(C+i*strideC, B+i*strideB, m*p*sizeof(double));
328         BLASCALL(dgetrs)("N", &m, &p, scratchA, &m, ipivot, C + i*strideC, &m, &info);
329         blas_return_check(info, "DGETRS");     
330       }
331     }
332    
333     mxFree(iwork);
334     mxFree(scratchA);
335     mxFree(dimsC);
336     mxFree(ipivot);
337     mxFree(work);
338     break;
339   case COMMAND_INV:
340     /******************************************
341      * INVERSE
342      ******************************************/
343     m = dimsA[0];
344     n = dimsA[1];
345     ipivot = (int *)mxMalloc(m*sizeof(int));
346     work = (double *)mxMalloc(m*m*sizeof(double));
347     iwork = (int *)mxMalloc(m*sizeof(int));
348
349     plhs[0] = mxDuplicateArray(mxA);
350     C = mxGetPr(plhs[0]);
351     strideC = n*m;
352
353     for(i=2; i<numDimsA; i++)
354       numPages *= dimsA[i];
355
356     for(i=0; i<numPages; i++) {
357       compute_lu(C + i*strideC, m, ipivot, work, iwork, 1);
358        
359       BLASCALL(dgetri)(&n, C + i*strideC, &m, ipivot, work, &n, &info );
360       blas_return_check(info, "DGETRI");
361       /*      if(info>0) mexWarnMsgTxt("Matrix is singular to working precision"); */
362     }
363    
364     mxFree(ipivot);
365     mxFree(work);
366     mxFree(iwork);
367     break;
368    
369   default:
370     mexErrMsgTxt("Should never get here");
371   }
372
373 }
374
375 /* Wrapper function for LU decomposition.  Optionally checks
376    singularity of result.  For efficiency, pass in the scratch
377    buffers.  Result appears in-place.  See BLAS docs on DGETRF and
378    DGECON for required scratch buffer sizes.  */
379 void compute_lu(double *X, int m, int *ipivot, double *work, int *iwork, int check_singular)
380 {
381   double anorm, rcond;
382   int info;
383   char errmsg[255];
384  
385   anorm = compute_norm(X, m, m);
386   BLASCALL(dgetrf)(&m, &m, X, &m, ipivot, &info); /* LU call */
387   blas_return_check(info, "DGETRF");
388  
389   if(check_singular) {
390     /* Check singularity */
391     if(info>0)
392       mexWarnMsgTxt("Matrix is singular to working precision");
393     else {
394       BLASCALL(dgecon)("1", &m, X, &m, &anorm, &rcond, work, iwork, &info);
395       blas_return_check(info, "DGECON");
396      
397       if(rcond < eps) {
398         sprintf(errmsg, "%s\n         %s RCOND = %e.",
399                 "Matrix is close to singular or badly scaled.",
400                 "Results may be inaccurate.", rcond);
401         mexWarnMsgTxt(errmsg);
402       }
403     }
404   }
405
406 }
407
/* Check the INFO value returned by a BLAS/LAPACK routine.  A negative INFO
   means argument number -INFO was illegal, which this file reports as an
   internal error naming the routine (blasfcn) and the argument.  Zero and
   positive values are left for the caller to interpret. */
void blas_return_check(int info, const char *blasfcn)
{
  char errmsg[255];

  if(info < 0) {
    /* info is known to be negative here, so -info replaces abs(), which
       would need <stdlib.h> (not included by this file).  The routine name
       is length-limited so the message always fits in errmsg. */
    sprintf(errmsg, "Internal error: Illegal %.64s call, problem in arg %i",
            blasfcn, -info);
    mexErrMsgTxt(errmsg);
  }
}
419
/* 1-norm of an m-by-n column-major matrix: the largest sum of absolute
   values over its columns.  Returns 0.0 when the matrix has no entries. */
double compute_norm(double *A, int m, int n)
{
  double largest = 0.0;
  int col = 0;

  while(col < n) {
    double colsum = 0.0;
    int row;

    for(row = 0; row < m; row++)
      colsum += fabs(A[col*m + row]);

    largest = (colsum > largest) ? colsum : largest;
    col++;
  }
  return largest;
}
Note: See TracBrowser for help on using the browser.