module blocklu
  use RealKind
contains
  subroutine lu1blas3(A,n,bsize,ierror)
    ! Gaussian Elimination (LU factorisation) without pivoting
    ! using blocking and blas level 3 for a square n X n matrix.
    !
    ! A is overwritten in place with its LU factors: the diagonal blocks
    ! are factorised by blu1blas2, the off-diagonal block row and block
    ! column are obtained by triangular solves (dtrsm) and the trailing
    ! submatrix is updated with a matrix-matrix product (dgemm).
    !
    ! Arguments
    !   A      (inout) n x n matrix, overwritten with the LU factors
    !   n      (in)    dimension of A
    !   bsize  (in)    block size; must be positive and divide n
    !   ierror (out)   0 on success, 1 if bsize is not a positive
    !                  divisor of n (A is then left untouched)
    !
    ! No pivoting is performed, so a zero pivot leads to a division by
    ! zero inside blu1blas2.
    !
    ! NOTE(review): dtrsm/dgemm are the double precision BLAS routines;
    ! this presumably requires rk (from RealKind) to be the double
    ! precision kind -- confirm.
    implicit none
    integer, intent(in) :: n, bsize
    real (kind=rk), intent(inout), dimension(n,n) :: A
    integer, intent(out) :: ierror       ! 1 on output if bsize is not a
                                         ! positive divisor of n
    integer :: nb                        ! nb = number of blocks
    integer :: i, j
    integer :: nrest                     ! rows/columns still to be
                                         ! factorised after block j

    ierror = 0
    ! A non-positive block size is rejected up front; this also avoids
    ! an integer division by zero in the tests below.
    if (bsize < 1) then
      print*,'error, block size must be positive'
      ierror = 1
      return
    end if
    ! Exact integer test. Comparing n/bsize with a default real quotient
    ! can wrongly accept large n that are not multiples of bsize.
    if (mod(n,bsize) /= 0) then
      print*,'error, block size does not divide dimension of A'
      ierror = 1
      return
    end if
    nb = n/bsize
    print*,'number of blocks', nb
    ! Nothing to factorise for an empty matrix; returning here keeps the
    ! final diagonal-block call below from indexing outside A.
    if (nb < 1) return

    do j = 1,nb-1
      i = (j-1)*bsize + 1               ! leading index of diagonal block j
      nrest = (nb-j)*bsize

      ! Step 1: LU decomposition of the bsize x bsize block with leading
      ! entry A(i,i). The block is overwritten with its LU factors.
      call blu1blas2(A,n,bsize,i)

      ! Step 2: unit lower triangular solve with many right-hand sides,
      ! overwriting the block row to the right of the diagonal block.
      call dtrsm('Left','Lower','No Transpose','Unit Diagonal', &
                 bsize,nrest,1.0_rk, &
                 A(i,i),n,A(i,i+bsize),n)

      ! Step 3: upper triangular solve with many right-hand sides,
      ! overwriting the block column below the diagonal block.
      call dtrsm('Right','Upper','No Transpose','Non Unit Diagonal', &
                 nrest,bsize,1.0_rk, &
                 A(i,i),n,A(i+bsize,i),n)

      ! Step 4: form the remainder matrix which is still to be
      ! factorised (trailing block minus block column times block row).
      call dgemm('No transpose','No transpose',  &
                 nrest,nrest,bsize, &
                 -1.0_rk,A(i+bsize,i),n,A(i,i+bsize),n, &
                  1.0_rk,A(i+bsize,i+bsize),n)
    end do

    ! LU decomposition of the final bsize x bsize diagonal block
    call blu1blas2(A,n,bsize,n-bsize+1)
  end subroutine lu1blas3

  subroutine blu1blas2(A,n,bsize,k)
  ! Gaussian Elimination without pivoting
  ! using the blas routines dscal and dger for a square bsize X bsize
  ! submatrix with top left entry A(k,k) of a square n x n matrix A.
  !
  ! The submatrix is overwritten in place: each elimination step scales
  ! the entries below the pivot by 1/pivot (the multipliers) and then
  ! applies a rank-one update to the rest of the submatrix.
  !
  ! Arguments
  !   A     (inout) n x n matrix holding the submatrix to factorise
  !   n     (in)    dimension (and leading dimension) of A
  !   bsize (in)    order of the submatrix
  !   k     (in)    row and column index of its top left entry
  !
  ! No pivoting is done and the pivot is not tested, so a zero pivot
  ! gives a division by zero.
  !
  ! NOTE(review): dscal/dger are the double precision BLAS routines;
  ! this presumably requires rk (from RealKind) to be the double
  ! precision kind -- confirm.
    implicit none
    integer, intent(in) :: n,bsize,k
    real (kind=rk), intent(inout), dimension(n,n) :: A
    integer :: i
    integer :: p            ! row and column index of the current pivot
    do i = 1,bsize-1
      p = k+i-1
      ! multipliers: divide the column below the pivot by the pivot
      call dscal(bsize-i,1.0_rk/A(p,p),A(p+1,p),1)
      ! rank-one update of the remaining (bsize-i) x (bsize-i) part:
      ! subtract (multiplier column) * (pivot row); the pivot row is
      ! traversed with stride n because A is stored by columns
      call dger(bsize-i,bsize-i,-1.0_rk,A(p+1,p),1,A(p,p+1),n, &
                A(p+1,p+1),n)
    end do
  end subroutine blu1blas2

  subroutine bs1(A,b,n,x)
  ! Forward and back substitution to solve A*x = b
  ! assuming that A contains its LU factors: the entries below the
  ! diagonal are used as L (with an implied unit diagonal) and the
  ! entries on and above the diagonal are used as U.
  !
  ! Arguments
  !   A (in)  n x n matrix holding the LU factors
  !   b (in)  right-hand side, length n
  !   n (in)  dimension of the system; for n < 1 no work is done
  !   x (out) solution, length n
  !
  ! The diagonal entries of A are used as divisors without being
  ! tested, so a zero on the diagonal gives a division by zero.
    implicit none
    integer, intent(in) :: n
    real (kind=rk), intent(in), dimension(n,n) :: A
    real (kind=rk), intent(in), dimension(n) :: b
    real (kind=rk), intent(out), dimension(n) :: x
    integer :: i,j
    ! first solve L*y = b
    ! y is accumulated directly in x, so no work array is needed; both
    ! loops are zero-trip when n < 1
    do i = 1,n
      x(i) = b(i)
      do j = 1,i-1
        x(i) = x(i) - A(i,j)*x(j)
      end do
    end do
    ! Now solve U*x = y
    ! on entry to step i, x(i) still holds y(i) and x(i+1:n) already
    ! hold the final solution
    do i = n,1,-1
      do j = i+1,n
        x(i) = x(i) - A(i,j)*x(j)
      end do
      x(i) = x(i)/A(i,i)
    end do
  end subroutine bs1
end module