高斯消元求解线性方程组(C++)

大一的时候在线性代数课程里学过,要选择列中最大的数作为pivot,目的是为了减少计算中产生的误差,昨晚自己写了一遍代码,发现果然如此。实践一下,更有助于学习!

测试数据:

A:
1.00 2.00 1.00 4.00
2.00 0.00 4.00 3.00
4.00 2.00 2.00 1.00
-3.00 1.00 3.00 2.00

B:
13.00
28.00
20.00
6.00

C=B/A:
3.00
-1.00
4.00
2.00

D=A*B:
113.00
124.00
154.00
61.00

#include <cassert>
#include <cstring>
#include <cmath>
#include <cstdio>

class Matrix{
private:
	int rows, cols;
	float* data;
public:
	int getRows()const{
		return rows;
	}
	int getColumns()const{
		return cols;
	}
	Matrix(): rows(0), cols(0), data(0){}
	Matrix(int r, int c){
		rows = r, cols = c;
		data = new float[r*c];
		this->zeros();
	}
	Matrix(int r, int c, float** d){
		rows = r, cols = c;
		data = new float[r*c];
		memcpy(data, d, sizeof(float)*r*c);
	}
	Matrix(const Matrix& m){
		rows = m.rows, cols = m.cols;
		data = new float[rows*cols];
		memcpy(data, m.data, sizeof(float)*rows*cols);
	}
	~Matrix(){
		delete data;
	}
	void set(int r, int c, float val){ 
		if(r*cols+c>=rows*cols)
			throw "Out of boundary!"; 
		data[r*cols+c] = val; 
	}
	float get(int r, int c) const { return data[r*cols+c]; }
	void zeros(){ memset(data, 0, sizeof(float)*rows*cols); }
	Matrix& copy(const Matrix& m, int r, int c){
		return copy(m, 0, 0, m.rows, m.cols, r, c);
	}
	Matrix& copy(const Matrix& m, int r, int c, int nr, int nc, int r0, int c0){
		for(int i=0; i<nr; i++)
			for(int j=0; j<nc; j++)
				this->set(i+r0, j+c0, m.get(i+r, j+c));
		return *this;
	}
	Matrix& swapRows(int r1, int r2){
		Matrix C(1, this->cols);
		C.copy(*this, r1, 0, 1, this->cols, 0, 0);
		this->copy(*this, r2, 0, 1, this->cols, r1, 0)
			.copy(C, 0, 0, 1, this->cols, r2, 0);
	}
	Matrix& operator = (const Matrix& m){
		delete data;
		this->rows = m.rows, this->cols = m.cols;
		data = new float[this->rows*this->cols];
		memcpy(data, m.data, sizeof(float)*rows*cols);
		return *this;
	}
	Matrix operator * (const Matrix& m){
		Matrix res(this->rows, m.cols);
		if(this->cols!=m.rows)
			throw "Matrix failed to multiply!";
		for(int i=0; i<res.rows; i++)
			for(int j=0; j<res.cols; j++)
				for(int k=0; k<this->cols; k++)
					res.data[i*res.cols+j] +=  this->get(i, k) * m.get(k, j);
		return res;
	}
	Matrix operator / (const Matrix& A){
		/* Gaussian Elimination
		 * AX=B   X=B/A */
		if(A.rows!=A.cols || this->cols!=1)
			throw "Invalid matrix A or B!";
		Matrix X(A.cols, 1), &B=*this, Aug(A.rows, A.cols+1);
		/* Make a augumented matrix */
		Aug.copy(A, 0, 0).copy(B, 0, A.cols);
		/* sort the pivots */
		for(int i=0; i<Aug.rows; i++){
			int maxi = i;
			for(int j=i+1; j<Aug.rows; j++)
				if(fabs(Aug.get(j, i)) > fabs(Aug.get(maxi, i)))
					maxi = j;
			if(maxi!=i) Aug.swapRows(i, maxi);
			if(Aug.get(i, i)==0) throw "Singular matrix!";
			for(int k=i+1; k<Aug.rows; k++){
				float m=Aug.get(k, i) / Aug.get(i, i);
				for(int v=0; v<Aug.cols; v++)
					Aug.set(k, v, Aug.get(k, v) - m*Aug.get(i, v));
			}
		}
		/* Back Substitution */
		int z = X.rows - 1;
		X.set(z,0, Aug.get(z, X.rows) / Aug.get(z, z));
		for(z--; z>=0; z--){
			float m = 0;
			for(int p=z+1; p<X.rows; p++)
				m += Aug.get(z, p) * X.get(p, 0);
			X.set(z, 0, (Aug.get(z, X.rows) - m) / Aug.get(z, z));
		}
		return X;
	}
	void print(){
		for(int i=0; i<rows; i++){
			for(int j=0; j<cols-1; j++)
				printf("%.2f ", data[i*cols+j]);
			printf("%.2f\n", data[i*cols+cols-1]);
		}
		printf("\n");
	}
};

int main(int argc, char **argv)
{
	float a[4][4]={	1.0, 2.0, 1.0, 4.0,
					2.0, 0.0, 4.0, 3.0,
					4.0, 2.0, 2.0, 1.0,
					-3.0, 1.0, 3.0, 2.0 };
	float b[4] = {13.0, 28.0, 20.0, 6.0 };
	Matrix A(4, 4, (float**)a), B(4, 1, (float**)b);
	puts("A:");
	A.print();
	puts("B:");
	B.print();
	Matrix C = B / A;
	puts("C=B/A:");
	C.print();
	Matrix D = A * B;
	puts("D=A*B:");
	D.print();
	return 0;
}

高斯消元求解线性方程组(C++)》有4个想法

  1. 西北木

    我用matlab计算c=b/a计算不出来 c=b\a计算结果为
    >> a=[1,2,1,4;2,0,4,3;4,2,2,1;-3,1,3,2];
    >> b=[13;28;20;6];
    >> d=a*b

    d =

    113
    124
    154
    61

    >> c=b\a

    c =

    0.0943 0.0518 0.1317 0.1210

    >> c=b/a
    ??? Error using ==> mrdivide
    Matrix dimensions must agree.

    回复
    1. Xiaoxia 文章作者

      呃。。。我本来也想用 \ 运算符的,但是据说C++里不能重载这个运算符,所以我才用了 / 。其实就是求解线性方程组AX=B。

      回复
    1. Xiaoxia 文章作者

      嗯,用Matrix能够很方便地表示一些数据嘛。个人觉得,其实就是一个二位数组。矩阵这名起得好听了~

      回复

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据