nmatrix.h

Go to the documentation of this file.
00001 #ifndef __nmatrix_h
00002 #define __nmatrix_h
00003 
00004 #include  "nvector.h"
00005 #include  "giota.h"
00006 
00007 /**************************************************/
00008 
00009 namespace Gascoigne
00010 {
00011 template<class T>
00012 class nmatrix : public nvector<T>
00013 {
00014 
00015   int nn,mm;
00016   
00017 public:
00018 
00019   typedef typename  nvector<T>::const_iterator  const_iterator;
00020   typedef typename  nvector<T>::iterator        iterator;
00021 
00022   /*
00023     Constructeurs
00024    */
00025 
00026   nmatrix<T>()             : nvector<T>()   , nn(0), mm(0) {}
00027   nmatrix<T>(size_t n)        : nvector<T>(n*n), nn(n), mm(n) {}       
00028   nmatrix<T>(size_t n, size_t m) : nvector<T>(n*m), nn(n), mm(m) {}       
00029   nmatrix<T>(size_t n, size_t m, const T& d) : nvector<T>(n*m,d), nn(n), mm(m) {}       
00030 
00031 
00032   /*
00033     Fonctions memoire
00034    */
00035   //void reservesize(size_t n, size_t m, T s = static_cast<T>(0))
00036   void reservesize(size_t n, size_t m, T s)
00037     {
00038                 nn = n;
00039                 mm = m;
00040                 nvector<T>::reservesize(n*m,s);
00041     }
00042   void reservesize(const nmatrix<T>& M)
00043     {
00044                 nmatrix<T>::reservesize(M.n(),M.m(),0);
00045     }
00046   void memory(size_t n, size_t m)
00047     {
00048       //nmatrix<T>::reservesize(n,m);
00049       nmatrix<T>::reservesize(n,m,0);
00050     }
00051   void resize(size_t n)
00052     {
00053       nmatrix<T>::memory(n,n);
00054    }
00055    void resize(size_t n, size_t m)
00056     {
00057       nmatrix<T>::memory(n,m);
00058     }
00059 
00060   /*
00061     Fonctions d'acces
00062    */
00063 
00064   size_t      n()                     const  { return nn; }
00065   size_t      m()                     const  { return mm; }
00066   const T& operator()(int i,int j) const  { return (*this)[j+mm*i];   }
00067   T&       operator()(int i,int j)        { return (*this)[j+mm*i];   }
00068   const T& value     (int i,int j) const  { return (*this)[j+mm*i];   }
00069   T&       value     (int i,int j)        { return (*this)[j+mm*i];   }
00070 
00071   nmatrix<T>&   operator=(const nmatrix<T>& A)
00072     {
00073       if( (n()!=A.n()) || (m()!=A.m()) )
00074         {
00075           std::cerr << "nmatrix<T>::operator=(const nmatrix<T>& A)\n";
00076           std::cerr << "no = possible " << n()<<" "<<A.n()<<" "<<m()<<" "<<A.m()<<std::endl; abort();
00077         }
00078       for(int i=0;i<n();i++)
00079         {
00080           for(int j=0;j<m();j++)
00081             {
00082               (*this)(i,j) = A(i,j);
00083             }
00084         }
00085       return *this;
00086    }
00087 
00088   const_iterator rowstart(int i) const {return std::vector<T>::begin()+mm*i;}
00089   
00090   void identity()
00091     {
00092     nvector<T>::zero();
00093       for(int i=0;i<n();i++)
00094         {
00095           (*this)(i,i) = 1.;
00096         }
00097     }
00098 
00099   /*
00100     Fonctions numeriques
00101    */
00102   T det() const
00103     {
00104       const nmatrix<T>& A = *this;
00105       T d;
00106       if( (n()==2) && (m()==2))
00107         {
00108           d = A(0,0)*A(1,1)-A(0,1)*A(1,0);
00109         }
00110       else if( (n()==3) && (m()==3))
00111         {
00112           d = A(0,0) * ( A(1,1)*A(2,2)-A(1,2)*A(2,1) ) 
00113             - A(1,0) * ( A(0,1)*A(2,2)-A(0,2)*A(2,1) ) 
00114             + A(2,0) * ( A(0,1)*A(1,2)-A(0,2)*A(1,1) );
00115         }
00116       else
00117         {
00118           std::cerr << " cul de sac dans nmatrix::det() "<<n()<<" "<<m()<<std::endl;
00119           exit(1);
00120         }
00121       return d;
00122     }
00123 
00124   void gram(const nmatrix<T>& B)
00125     {
00126       nvector<T>::zero();
00127       for(int i=0;i<n();i++)
00128         {
00129           for(int j=0;j<m();j++)
00130             {
00131               for(int k=0;k<B.n();k++)
00132                 {
00133                   (*this)(i,j) += B(k,i) * B(k,j);
00134                 } 
00135             }
00136         }
00137     }
00138   void transpose()
00139     {
00140       nmatrix<T> B(*this);
00141       resize(m(),n());
00142       for(int i=0;i<m();i++)
00143         {
00144           for(int j=0;j<n();j++)
00145             {
00146               (*this)(i,j) = B(j,i);
00147             }
00148         }
00149     }
00150 
00151   void mmult(nmatrix<T>& A, const nmatrix<T>& B) const
00152     {
00153       A.zero();
00154       for(int i=0;i<A.n();i++)
00155         {
00156           for(int j=0;j<A.m();j++)
00157             {
00158               for(int k=0;k<m();k++)
00159                 {
00160                   A(i,j) += (*this)(i,k) * B(k,j);
00161                 } 
00162             }
00163         }
00164     }
00165 
00166   void mmult_ad(nmatrix<T>& A, const nmatrix<T>& B) const
00167     {
00168       /* A = (*this)^T * B */
00169       A.zero();
00170       for(int i=0;i<A.n();i++)
00171         {
00172           for(int j=0;j<A.m();j++)
00173             {
00174               for(int k=0;k<m();k++)
00175                 {
00176                   A(i,j) += (*this)(k,i) * B(k,j);
00177                 } 
00178             }
00179         }
00180     }
00181 
00182   template<class VECTOR>
00183     void Mult(VECTOR& y, const VECTOR& x, double s=1.) const
00184     {
00185       const_iterator                    p  = std::vector<T>::begin();
00186       typename VECTOR::iterator         py = y.begin();
00187       typename VECTOR::const_iterator   px;
00188       
00189       while(p!=std::vector<T>::end())
00190         {
00191           px = x.begin();
00192           for(int j=0;j<m();j++)
00193             {
00194               *py += s * (*p++) * (*px++);
00195             }
00196           py++;
00197         }
00198     }
00199   template<class VECTOR>
00200     void mult(VECTOR& y, const VECTOR& x) const
00201     {
00202       const_iterator                    p  = std::vector<T>::begin();
00203       typename VECTOR::iterator         py = y.begin();
00204       typename VECTOR::const_iterator   px;
00205       
00206       while(p!=std::vector<T>::end())
00207         {
00208           px = x.begin();
00209           for(int j=0;j<m();j++)
00210             {
00211               *py += (*p++) * (*px++);
00212             }
00213           py++;
00214         }
00215     }
00216   template<class VECTOR>
00217     void multtrans(VECTOR& y, const VECTOR& x, double s=1.) const
00218     {
00219       const_iterator                              p  = std::vector<T>::begin();
00220       typename VECTOR::iterator         py = y.begin();
00221       typename VECTOR::const_iterator   px = x.begin();
00222       
00223       while(p!=std::vector<T>::end())
00224         {
00225           py = y.begin();
00226           for(int j=0;j<m();j++)
00227             {
00228               (*py) += s* (*p++) * (*px);
00229               py++;
00230             }
00231           px++;
00232         }
00233     }
00234 
00235   template<class VECTOR>
00236     void multeq(VECTOR& y, const VECTOR& x, double s=1.) const
00237     {
00238       const_iterator                              p  = std::vector<T>::begin();
00239       typename VECTOR::iterator         py = y.begin();
00240       typename VECTOR::const_iterator   px;
00241       
00242       while(p!=std::vector<T>::end())
00243         {
00244           px = x.begin();
00245           *py = 0.;
00246           for(int j=0;j<m();j++)
00247             {
00248               *py += s * (*p++) * (*px++);
00249             }
00250           py++;
00251         }
00252     }
00253 
00254   template<class VECTOR>
00255     void mult_ad(VECTOR& y, const VECTOR& x) const
00256     {
00257       // vmult with the adjoint matrix
00258       const_iterator                              p  = std::vector<T>::begin();
00259       typename VECTOR::iterator         py = y.begin();
00260       typename VECTOR::const_iterator   px = x.begin();
00261       
00262       while(p!=std::vector<T>::end())
00263         {
00264           py = y.begin();
00265           for(int j=0;j<m();j++)
00266             {
00267               (*py) += (*p++) * (*px);
00268               py++;
00269             }
00270           px++;
00271         }
00272     }
00273 
00274   template<class ITER1, class ITER2>
00275     void mult_ad(ITER1 py, ITER2 px, double s=1.) const
00276     {
00277       // vmult with the adjoint matrix
00278       const_iterator    p  = std::vector<T>::begin();
00279       
00280       while(p!=std::vector<T>::end())
00281         {
00282           for(int j=0;j<m();j++)
00283             {
00284               (*py) += s*(*p++) * (*px);
00285               py++;
00286             }
00287           py -= m();
00288           px++;
00289         }
00290     }
00291 
00292   template<class ITER1, class ITER2>
00293   void mult(ITER1 py, ITER2  px0, double s=1.) const
00294   {
00295     const_iterator    p  = std::vector<T>::begin();
00296     ITER2   px;
00297     
00298     while(p!=std::vector<T>::end())
00299       {
00300         px = px0;
00301         for(int j=0;j<m();j++)
00302           {
00303             *py += s* (*p++) * (*px++);
00304           }
00305         py++;
00306       }
00307   }
00308 
00309   template<class VECTOR>
00310     void multeq_ad(VECTOR& y, const VECTOR& x) const
00311     {
00312       // vmulteq with the adjoint matrix
00313       const_iterator                              p  = std::vector<T>::begin();
00314       typename VECTOR::iterator         py = y.begin();
00315       typename VECTOR::const_iterator   px = x.begin();
00316       
00317       y.zero();
00318       while(p!=std::vector<T>::end())
00319         {
00320           for(int j=0;j<m();j++)
00321             {
00322               (*py) += (*p++) * (*px);
00323               py++;
00324             }
00325           py -= m();
00326           px++;
00327         }
00328     }
00329 
00330   template<class ITER1, class ITER2>
00331     void multeq_ad(ITER1 py, ITER2 px) const
00332     {
00333       // vmult with the adjoint matrix
00334       const_iterator    p  = std::vector<T>::begin();
00335       
00336       for(int i=0;i<m();i++)
00337         {
00338           (*py++) = 0.;
00339         }
00340       py -= m();
00341       while(p!=std::vector<T>::end())
00342         {
00343           for(int j=0;j<m();j++)
00344             {
00345               (*py) += (*p++) * (*px);
00346               py++;
00347             }
00348           py -= m();
00349           px++;
00350         }
00351     }
00352 
00353   /**************************************************/
00354   
00355   void lu()
00356     {
00357       /* LU decomposition */
00358       
00359       for(int i=1;i<n();i++)
00360         {
00361           for(int k=0;k<i;k++)
00362             {
00363               value(i,k) /= value(k,k);
00364               for(int j=k+1;j<n();j++)
00365                 {
00366                   value(i,j) -= value(i,k)*value(k,j);
00367                 }
00368             }
00369         }
00370     }
00371 
00372   /**************************************************/
00373   
00374   void inverse2()
00375     {
00376       T a = value(0,0);
00377       T b = value(0,1);
00378       T c = value(1,0);
00379       T d = value(1,1);
00380       value(0,0) = d;
00381       value(0,1) = -b;
00382       value(1,0) = -c;
00383       value(1,1) = a;
00384 
00385       T idet = 1./(a*d-b*c);
00386       *this *= idet;
00387     }
00388 
00389   /**************************************************/
00390   
00391   void inverse3()
00392     {
00393       T idet = 1./det();
00394 
00395       T a = value(0,0);
00396       T b = value(0,1);
00397       T c = value(0,2);
00398       T d = value(1,0);
00399       T e = value(1,1);
00400       T f = value(1,2);
00401       T g = value(2,0);
00402       T h = value(2,1);
00403       T i = value(2,2);
00404       value(0,0) = e*i-f*h;
00405       value(0,1) = h*c-b*i;
00406       value(0,2) = b*f-e*c;
00407       value(1,0) = f*g-d*i;
00408       value(1,1) = a*i-g*c;
00409       value(1,2) = d*c-a*f;
00410       value(2,0) = d*h-g*e;
00411       value(2,1) = g*b-a*h;
00412       value(2,2) = a*e-b*d;
00413 
00414       *this *= idet;
00415     }
00416 
00417   /**************************************************/
00418   
00419   void inverse()
00420     {
00421       if (n()==2)
00422         {
00423           inverse2();
00424           return;
00425         }
00426       if (n()==3)
00427         {
00428           inverse3();
00429           return;
00430         }
00431       /* LU decomposition */
00432       
00433       for(int i=1;i<n();i++)
00434         {
00435           for(int k=0;k<i;k++)
00436             {
00437               value(i,k) /= value(k,k);
00438               for(int j=k+1;j<n();j++)
00439                 {
00440                   value(i,j) -= value(i,k)*value(k,j);
00441                 }
00442             }
00443         }
00444 
00445       /* Inverse von L */
00446       
00447       for(int ncol=0;ncol<n()-1;ncol++)
00448         {
00449           for(int i=ncol+1;i<n();i++)
00450             {
00451               value(i,ncol) = -value(i,ncol);
00452               for(int k=ncol+1;k<i;k++)
00453                 {
00454                   value(i,ncol) -= value(i,k)*value(k,ncol);
00455                 }
00456             }
00457         }
00458 
00459 
00460       /* Inverse von U */
00461       
00462       
00463       for(int nlin=0;nlin<n();nlin++)
00464         {
00465           for(int j=nlin+1;j<n();j++)
00466             {
00467               value(nlin,j) /= -value(nlin,nlin);
00468               for(int k=nlin+1;k<j;k++)
00469                 {
00470                   value(nlin,j) -= value(nlin,k)*value(k,j);
00471                 }
00472               value(nlin,j) /= value(j,j);
00473             }
00474           value(nlin,nlin) = 1./value(nlin,nlin);
00475         }
00476       
00477       
00478       /* Inverse von A */
00479       
00480       for(int ncol=0;ncol<n();ncol++)
00481         {
00482           for(int i=0;i<ncol+1;i++)
00483             {
00484               for(int k=ncol+1;k<n();k++)
00485                 {
00486                   value(i,ncol) += value(i,k)*value(k,ncol);
00487                 }
00488             }
00489           for(int i=ncol+1;i<n();i++)
00490             {
00491               value(i,ncol) *= value(i,i);
00492               for(int k=i+1;k<n();k++)
00493                 {
00494                   value(i,ncol) += value(i,k)*value(k,ncol);
00495                 }
00496             }
00497         }
00498     }
00499 
00500   void gauss_jordan()
00501     {
00502       nvector<int> p(n());
00503       iota(p.begin(),p.end(),0);
00504       
00505       for (int j=0;j<n();j++)
00506         {
00507           double max = fabs(value(j,j));
00508           int r = j;
00509           for (int i=j+1;i<n();i++)
00510             {
00511               if (fabs(value(i,j)) > max)
00512                 {
00513                   max = fabs(value(i,j));
00514                   r = i;
00515                 }
00516             }
00517           if (r>j)
00518             {
00519               for (int k=0; k<n(); k++)
00520                 {
00521                   //swap(value(j,k),value(r,k));
00522                   T h        = value(j,k);
00523                   value(j,k) = value(r,k);
00524                   value(r,k) = h;
00525                 }
00526               //swap(p[j],p[r]);
00527               int h = p[j]; p[j] = p[r]; p[r] = h;
00528             }
00529           
00530           double hr = 1./value(j,j);
00531           value(j,j) = hr;
00532           for (int k=0;k<n();k++)
00533             {
00534               if (k==j) continue;
00535               for (int i=0;i<n();i++)
00536                 {
00537                   if (i==j) continue;
00538                   value(i,k) -= value(i,j)*value(j,k)*hr;
00539                 }
00540             }
00541           for (int i=0;i<n();i++)
00542             {
00543               value(i,j) *= hr;
00544               value(j,i) *= -hr;
00545             }
00546           value(j,j) = hr;
00547         }
00548       nvector<double> hv(n());
00549       for (int i=0;i<n();i++)
00550         {
00551           for (int k=0;k<n();k++) hv[p[k]] = value(i,k);
00552           for (int k=0;k<n();k++) value(i,k) = hv[k];
00553         }
00554     }
00555 };
00556 }
00557 
00558 #endif
00559 

Generated on Thu Sep 14 10:34:35 2006 for Gascoigne by  doxygen 1.4.7