#!/usr/bin/env python
from scipy import *
from time import time
import Gnuplot, Gnuplot.funcutils
import quad
import pdb
from decimal import *
# Function definitions: -------------------------------
def initialise(ny_in,\
         p4a_in,p4t_in,p4f_in,p4h_in,p4r_in,p4b_in,p0b_in,\
         lambda_0_in,lambda_1_in):
  """Store the run parameters in module-level globals and echo them.

  ny_in       -- value stored in the global ``ny``.
  p4a_in ... p4b_in -- switches selecting which quantities are taken from
                 the quadruple-precision ``quad`` module (stored in
                 ``p4a``, ``p4t``, ``p4f``, ``p4h``, ``p4r``, ``p4b``).
  p0b_in      -- stored in ``p0b`` (0 zeroes the non-repeating beta_ij).
  lambda_0_in, lambda_1_in -- stored in ``lambda_0`` / ``lambda_1``.

  Returns None; the only output is one 'parameters:' line on stdout.
  """
  global ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  global lambda_0,lambda_1
  ny=ny_in
  p4a,p4t,p4f=p4a_in,p4t_in,p4f_in
  p4h,p4r,p4b=p4h_in,p4r_in,p4b_in
  p0b=p0b_in
  lambda_0=lambda_0_in
  lambda_1=lambda_1_in
  settings=(ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b,lambda_0,lambda_1)
  # Single-argument print: same text under Python 2 and Python 3.
  print(' '.join(['parameters:']+[str(v) for v in settings]))

def f2(x): # used for the RHS (checking ny=0.1:0.9):
  """Return a constant right-hand side: one 1.0 per element of x."""
  count=len(x)
  return ones(count)

def f(x): # was used for the RHS in all calculations in paper...:
           #   (with known solution)
  """Evaluate the right-hand side that belongs to the known solution.

  The closed-form expression is evaluated on the first half of ``x``
  (including the middle element when ``len(x)`` is odd) and mirrored onto
  the second half, so the result is symmetric about the middle index.
  When the global switch ``p4h`` is set, the values are taken from
  ``quad.quadruple.calculate_f_fii`` instead; the double-precision values
  are then only used to print the maximal difference.

  x       -- 1-D array of abscissae (needs slicing and ``**`` support).
  returns -- array with ``len(x)`` right-hand-side values.

  NOTE: an abscissa of exactly 0 or 1 produces nan (0*log(0)).
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  if p4h:
    ff90=zeros(len(x))
    quad.quadruple.calculate_f_fii(ff90)
  npts=len(x)
  # Floor division keeps the slice bounds integral on Python 2 and 3.
  half=npts//2      # number of entries obtained by mirroring
  l1=half+npts%2    # number of entries evaluated directly
  xl=x[0:l1]
  ff=zeros(npts)
  ff[0:l1]=1.-pi/2.-2.*xl**0.5-2*(1.-xl)**0.5-xl*log(1.+(1.-xl)**0.5)-(1.-xl)*log(1.+xl**0.5)\
     +0.5*xl*log(xl)+0.5*(1.-xl)*log(1.-xl)
  # Using the symmetry of f(fii): the last `half` entries mirror the first.
  ff[npts-half:]=ff[0:half][::-1]
  if p4h:
    print('max diff between double and quadruple f(fii) is: %s'
          % (abs(ff90-ff).max(),))
    return ff90
  return ff

def known_solution(x):
  """Evaluate the known solution u(x) = 1 + x**0.5 + (1-x)**0.5.

  The formula is applied to the first half of ``x`` (including the middle
  element for odd lengths); the second half is filled by mirroring, so the
  result is symmetric about the middle index.  When the global switch
  ``p4r`` is set, the values from
  ``quad.quadruple.calculate_known_sol_fii`` are returned instead and the
  maximal difference to the double-precision values is printed.

  x       -- 1-D array of abscissae (needs slicing and ``**`` support).
  returns -- array with ``len(x)`` values.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  if p4r:
    uf90=zeros(len(x))
    quad.quadruple.calculate_known_sol_fii(uf90)
  npts=len(x)
  # Floor division keeps the slice bounds integral on Python 2 and 3.
  half=npts//2      # number of entries obtained by mirroring
  l1=half+npts%2    # number of entries evaluated directly
  u=zeros(npts)
  u[0:l1]=1.+x[0:l1]**0.5+(1.-x[0:l1])**0.5
  # Using the symmetry of sol: the last `half` entries mirror the first.
  u[npts-half:]=u[0:half][::-1]
  if p4r:
    print('max diff between double and quadruple u(fii) is: %s'
          % (abs(uf90-u).max(),))
    return uf90
  return u

def a(t,s): # should return a matrix actually for better performance...
  """Kernel coefficient a(t,s); constant 1.0 whatever t and s are."""
  coefficient=1.
  return coefficient

def b(t,s):
  """Kernel coefficient b(t,s); constant 0.0 whatever t and s are."""
  coefficient=0.
  return coefficient

# Startup functions:
def calculate_c_star(r_0,r_1): #(24)
  """Return c_* = (r_1-1)! / (r_0*(r_0+1)*...*(r_0+r_1-1)).

  Multiplications and divisions are interleaved in one loop so that the
  intermediate value stays moderate.  (An older variant summed
  sign/(r_0+k)*comb(r_1-1,k) over k instead.)

  When any of the global switches p4a, p4t, p4f, p4r is set, or p4b==1,
  the quadruple-precision value is also computed through ``quad`` and the
  absolute difference is printed.

  r_0, r_1 -- positive integers.
  returns  -- the constant as a float.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  rr=r_0+r_1-1
  c=1.0
  for k in range(0,rr+1):
    if 0<k<r_1:
      c=c*k       # builds (r_1-1)!
    kk=rr-k
    if kk>=r_0:
      c=c/kk      # divides by r_0 ... rr
  # check:
  if p4a+p4t+p4f+p4r or p4b==1:
    cf90=zeros(1)
    quad.quadruple.calculate_c_star(r_0,r_1,n,m,m_0,m_1,ny,cf90)
    print('    c_star-cf90: %s' % (abs(c-cf90),))
  return c

def calculate_alpha_p():
  """Return the m+1 coefficients alpha_p built from the roots of P_m.

  Steps: take the real parts of the roots of ``calculate_charPm()`` that
  lie in (-1,0), build gamma_m from them ((15) in the paper) and combine
  gamma_m with signed binomial coefficients into ``alphap``.
  With the global switch ``p4a`` set, the ``quad`` routine of the same
  name is called first.

  For checking: the original notes list the known values for m=4 as
  1/36, -5/18, 3/2, -5/18, 1/36.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  if p4a:
    quad.quadruple.calculate_alpha_p()
  char_poly=calculate_charPm()
  candidates=char_poly.r.real
  r_10=zeros(m_0) # roots between (-1,0)
  filled=0
  for root in candidates:
    if -1<root<0:
      r_10[filled]=root
      filled+=1
  r_10.sort()
  char_deriv=char_poly.deriv() # derivative
  gamma_m=zeros(m_1)
  gamma_m[0]=1.0
  q=arange(0,m_1)
  q_tail=q[1:m_1]
  for root in r_10: # (15) in the paper
    gamma_m[1:m_1]=gamma_m[1:m_1]+(1.0+root)*root**(m_0+q_tail-1) \
                   / (1.0-root)**(2*q_tail+1) \
                   / char_deriv(root)
  alphap=zeros(m+1)
  for i in range(m+1):
    k=i-m_1+1
    for j in range(abs(k),m_1):
      alphap[i]+=(-1)**(k+j)*comb(2*j,k+j)*gamma_m[j]
  return alphap

def calculate_charPm():
  """Build the characteristic polynomial P_m as a ``poly1d``.

  The coefficients are sums of ``mult[i]*(x-i)**(m-1)`` over all i in
  0..m with x-i > 0, evaluated at the nodes x = -m_0+m/2, ..., m_0+m/2,
  where mult[i] = (-1)**i * m / (i! * (m-i)!).  Uses the globals m, m_0.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  nodes=arange(-m_0,m_0+1)+m/2.
  fact=ones(m+1)
  for i in range(2,m+1):
    fact[i]=fact[i-1]*i # contains factorials now
  mult=ones(m+1)
  for i in range(m+1):
    mult[i]=(-1.0)**i*m/fact[i]/fact[m-i]
  pm=zeros(len(nodes))
  for j in range(len(nodes)):
    for i in range(m+1):
      shifted=nodes[j]-i
      if shifted > 0.0:
        pm[j]=pm[j]+mult[i]*shifted**(m-1)
  return poly1d(pm[::-1])
      
def calculate_fii(t,r_0,r_1):
  """Evaluate fi(t) = t**r_0 * sum_{k<r_1} C(r_0+k-1,k) * (1-t)**k.

  The binomial factor and the power of (1-t) are updated incrementally.
  When any of the global switches p4a, p4f, p4r is set, the ``quad``
  routine is evaluated as well and the maximal difference is printed.

  t        -- 1-D array of parameter values.
  r_0, r_1 -- integer exponents.
  returns  -- array of len(t) values.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  compare=p4a+p4f+p4r
  if compare:
    rf90=zeros(len(t))
    quad.quadruple.calculate_fi(len(t),t,rf90)
  ff=ones((len(t)),dtype=float)
  tt=ones(len(t))  # running binomial factor
  tk=ones(len(t))  # running power (1-t)**k
  for k in range(1,r_1):
    tt=tt*(r_0+k-1)/k
    tk=tk*(1-t)
    ff=ff+tt*tk
  ff=ff*t**r_0
  if compare:
    print('    max(abs(fi-fi_rq)) %s' % (abs(ff-rf90).max(),))
  return ff

def calculate_fii_old(t,r_0,r_1,c_star):
  """Older evaluation of fi(t) through an alternating binomial sum.

  Computes t**r_0/c_star * sum_{k<r_1} (-1)**k/(r_0+k)*comb(r_1-1,k)*t**k.
  When any of the global switches p4a, p4f, p4r is set, the ``quad``
  routine is evaluated as well and the maximal difference is printed.

  t        -- 1-D array of parameter values.
  r_0, r_1 -- integer exponents.
  c_star   -- normalising constant (see calculate_c_star).
  returns  -- array of len(t) values.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  compare=p4a+p4f+p4r
  if compare:
    rf90=zeros(len(t))
    quad.quadruple.calculate_fi(len(t),t,rf90)
  ff=zeros((len(t)),dtype=float)
  sign=-1
  tt=ones(len(t))  # running power t**k
  for k in range(0,r_1):
    sign=-sign
    ff=ff+1.*sign/(r_0+k)*comb(r_1-1,k)*tt
    if k<r_1-1:
      tt=tt*t
  ff=ff/c_star*t**r_0
  if compare:
    print('old max(abs(fi-fi_rq)) %s' % (abs(ff-rf90).max(),))
  return ff

def calculate_fii_prim(t,r_0,r_1,c_star):
  """Return t**(r_0-1) * (1-t)**(r_1-1) / c_star.

  With the global switch p4a or p4f set, the ``quad`` routine
  calculate_fi_prim is called first (its result is not used here).
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  if p4a+p4f:
    quad.quadruple.calculate_fi_prim(len(t),t)
  rising=t**(r_0-1)
  falling=(1-t)**(r_1-1)
  return rising*falling/c_star

def Get_PHI(t,s,r_0,r_1,fii_prim):
  """Return the PHI matrix, using a file cache in directory 'PreCalc'.

  If 'PreCalc/PHI_<len(t)>_<len(s)>_<r_0>_<r_1>.data' exists it is read
  and reshaped to (len(t),len(s)); otherwise calculate_PHI is called and
  its result is written to that file.  Timing is printed in both cases.

  NOTE(review): the cache key contains only the sizes and r_0/r_1, not
  the values of t, s or fii_prim -- a stale file is reused silently if
  those change while the sizes stay the same.
  """
  import os
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  idim=len(t)
  jdim=len(s)
  if not os.path.exists('PreCalc'):
    os.mkdir('PreCalc')
  filen='PreCalc/PHI_'+str(idim)+'_'+str(jdim)+'_'+str(r_0)+'_'+str(r_1)+'.data'
  tim=time()
  if os.path.exists(filen):
    # if the file exists, read from there
    PHI1=fromfile(file=filen)
    print('  PHI read from file: %s' % (filen,))
    tim=time()-tim
    print('Time for getting PHI (in Python) %s' % (tim,))
    return PHI1.reshape((idim,jdim))
  # otherwise calculate and save the result
  PHI=calculate_PHI(t,s,r_0,r_1,fii_prim)
  PHI.tofile(file=filen)
  print('  PHI written to file: %s' % (filen,))
  tim=time()-tim
  print('Time for calculating PHI (in Python) %s' % (tim,))
  return PHI

def calculate_PHI(t,s,r_0,r_1,fii_prim): # (66)
  """Compute the matrix PHI without forming a divided difference.

  The strictly lower triangle is accumulated row by row from powers of
  s[j], (1-t[i]) and (t[i]-s[j]) weighted by factorial ratios; it is then
  copied to the upper triangle, and the diagonal is set to ``fii_prim``.
  With the global switch ``p4a`` set, the ``quad`` result is computed as
  well and the maximal difference is printed.

  t, s     -- 1-D node arrays; the code indexes r_2[i,i] and mirrors the
              triangle, so it presumably expects len(t)==len(s).
  r_0, r_1 -- integer exponents.
  fii_prim -- diagonal values, one per node.
  returns  -- (len(t), len(s)) array.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  idim=len(t)
  jdim=len(s)
  if p4a:
    r1=zeros(idim*jdim)
    quad.quadruple.calculate_phi(len(t),len(s),t,s,r1)
    rf90=r1.reshape(idim,jdim)
  # calculate the factorials.
  r_01=max(r_0,r_1)
  den12=ones(r_01) # these are for the first two denominators
  for p in range(1,r_01):
    pp=r_01-1-p
    den12[pp]=den12[pp+1]*p
  rr=r_0+r_1
  mult=ones(rr) # (r_0+r_1-1)!/(p+q+1)!
  for i in range(rr-2,0,-1):
    mult[i]=mult[i+1]*(i+1)
  # find the coefficients
  ff=ones((r_0,r_1))
  for p in range(r_0):
    for q in range(r_1):
      ff[p,q]=mult[p+q+1]/den12[p]/den12[q]
  sp=zeros((r_0,jdim))
  t1p=zeros((r_1,idim))
  # Calculate t1p powers:
  t1p[r_1-1,:]=1.0
  for q in range(r_1-2,-1,-1):
    t1p[q,:]=t1p[q+1,:]*(1.0-t)
  # s[j] powers
  sp[r_0-1,:]=1.0
  for p in range(r_0-2,-1,-1):
    sp[p,:]=sp[p+1,:]*s
  # version with reordered loops:
  r_2=zeros((idim,jdim))
  tsp_2=ones((r_0+r_1,jdim))
  for i in range(idim):
    r_2[i,i]=fii_prim[i]
    # powers of (t[i]-s[j]) for all j<i; row 0 stays at one
    for k in range(1,r_0+r_1):
      tsp_2[k,0:i]=tsp_2[k-1,0:i]*(t[i]-s[0:i])
    for p in range(r_0):
      for q in range(r_1):
        r_2[i,0:i]=r_2[i,0:i]+ff[p,q]*sp[p,0:i]*  \
          t1p[q,i]*tsp_2[p+q,0:i]
    # mirror the finished row part into the column above the diagonal
    r_2[0:i,i]=r_2[i,0:i]
  if p4a:
    print('    max(abs(PHI-PHI_rq)) %s' % (abs(r_2-rf90).max(),))
  return r_2

def calculate_PHI_old(t,s,fii,fii_prim):
  """Older PHI: divided differences (fii[i]-fii[j])/(t[i]-s[j]).

  The diagonal (i==j) is set to fii_prim[i]; negative quotients are
  replaced by 0 (IMPORTANT: avoids nan-s in the Acurly calculation).
  Prints the elapsed time, and with the global switch ``p4a`` set also
  the maximal difference to the ``quad`` result.

  returns -- (len(t), len(s)) array.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  idim=len(t)
  jdim=len(s)
  if p4a:
    r1=zeros(idim*jdim)
    quad.quadruple.calculate_phi(len(t),len(s),t,s,r1)
    rf90=r1.reshape(idim,jdim)
  tim=time()
  r=zeros((idim,jdim))
  for i in range(idim):
    for j in range(jdim):
      if i==j:
        r[i,j]=fii_prim[i]
      else:
        r[i,j]=(fii[i]-fii[j])/(t[i]-s[j])
        if r[i,j]<0.: # IMPORTANT! to avoid nan-s in Acurly calculations...
          r[i,j]=0.
  tim=time()-tim
  print('Time for calculating old PHI (in Python) %s' % (tim,))
  if p4a:
    print('old max(abs(PHI-PHI_rq)) %s' % (abs(r-rf90).max(),))
  return r

def Calculate_Dm_gamjt():
  """Return the m-th differences of gamma_j tilde (n-m+1 values).

  The precision is selected by the global ``p4b``:
    p4b == 1 -- values come from quad.quadruple.calculate_delta_m_gamjt;
    p4b  > 1 -- decimal arithmetic with p4b digits (Decimal_gamjt and
                Decimal_D1_m), converted to floats at the end;
    otherwise -- plain double precision (Calculate_gamjt, Calculate_D1_m).
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  lenn=n-m+1
  if p4b==1:
    gam90=zeros(lenn)
    quad.quadruple.calculate_delta_m_gamjt(gam90)
    return gam90
  if p4b>1:
    getcontext().prec=p4b
    dec_gamjt=Decimal_gamjt()
    dec_gg=Decimal_D1_m(dec_gamjt)
    # only the first lenn entries of dec_gg are read
    gg=zeros(lenn)
    for i in range(lenn):
      gg[i]=float(dec_gg[i])
    return gg
  gamjt=Calculate_gamjt()
  return Calculate_D1_m(gamjt)

def calculate_Acurly(t,s,fii,PHI,fii_prim): # (27)
  """Assemble the matrix Acurly row by row.

  Row i is a(fii[i],fii) * PHI[i,:]**(-ny) * fii**(-lambda_0)
  * (1-fii)**(-lambda_1) * fii_prim, using the globals ny, lambda_0 and
  lambda_1.  With the global switch ``p4a`` set, the ``quad`` result is
  computed too, but only to print the maximal difference; the
  double-precision matrix is returned in every case.

  returns -- (len(t), len(s)) array.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b,lambda_0,lambda_1
  idim=len(t)
  jdim=len(s)
  if p4a:
    r1=zeros(idim*jdim)
    quad.quadruple.calculate_acurly(idim,jdim,t,s,r1)
    rf90=r1.reshape(idim,jdim)
  r=zeros((idim,jdim))
  for i in range(idim):
    r[i,:]=a(fii[i],fii[:])*PHI[i,:]**(-ny)*fii[:]**(-lambda_0)*(1.-fii[:])**(-lambda_1)*fii_prim[:]
  if p4a:
    print('max diff between double and quadruple Acurly is: %s'
          % (abs(rf90-r).max(),))
  return r

def calculate_Bcurly(t,s,fii,fii_prim):
  """Assemble the matrix Bcurly: row i is b(fii[i],fii) * fii_prim.

  t and s are only used for the matrix dimensions.
  returns -- (len(t), len(s)) array.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  idim=len(t)
  jdim=len(s)
  r=zeros((idim,jdim))
  for i in range(idim):
    r[i,:]=b(fii[i],fii[:])*fii_prim[:]
  return r

def Calculate_D2(w): # Calculate derivative of matrix on 2nd index
  """Forward difference along the second index: w[:,j+1]-w[:,j].

  w       -- 2-D array with at least one row and one column.
  returns -- new float array with one column less than ``w``.
  """
  idim=len(w[:,0])
  jdim=len(w[0,:])-1
  Dw=zeros((idim,jdim))
  # one slice expression replaces the former column-by-column loop
  Dw[:,:]=w[:,1:]-w[:,:-1]
  return Dw

def Calculate_D2_m(w):
  """Apply Calculate_D2 m times (m is the global difference order)."""
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  result=w
  for _sweep in range(m):
    result=Calculate_D2(result)
  return result

def Calculate_D1(w): # Calculate difference of vector
  """Forward difference of a vector: Dw[j] = w[j+1]-w[j].

  w       -- 1-D sequence of numbers.
  returns -- new float array with len(w)-1 entries (empty when ``w`` has
             fewer than two entries).
  """
  dim=max(len(w)-1,0)
  Dw=zeros((dim))
  # slice arithmetic replaces the former element-by-element loop
  Dw[:]=w[1:]
  Dw-=w[:dim]
  return Dw

def Calculate_D1_m(w):
  """Apply Calculate_D1 m times (m is the global difference order)."""
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  result=w
  for _sweep in range(m):
    result=Calculate_D1(result)
  return result

def Decimal_D1_m(w):
  """Difference the list ``w`` m times IN PLACE and return it.

  Every sweep overwrites w[j] with w[j+1]-w[j] for all j < len(w)-1.
  The length never shrinks, so after m sweeps only the first len(w)-m
  entries hold m-th differences; the remaining tail is meaningless and
  must not be read.  The caller's list is modified.

  w       -- mutable sequence (e.g. list of Decimal).
  returns -- the same object ``w``.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  last=len(w)-1
  for _sweep in range(m):
    for j in range(last):
      w[j]=w[j+1]-w[j]
  return w

def Calculate_Delta_jm(w):
  """m-th forward difference along the second index, in one pass.

  Column j of the result is sum_k (-1)**(m+k) * comb(m,k) * w[:,j+k] for
  k = 0..m, where m is the global difference order.

  w       -- 2-D array with at least one row.
  returns -- float array with m columns less than ``w``.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  nrows=len(w[:,0])
  ncols=len(w[0,:])
  width=ncols-m
  Dw=zeros((nrows,width))
  signed_one=(-1.)**m
  for k in range(m+1):
    weight=signed_one*comb(m,k)
    signed_one=-signed_one
    # all result columns are updated at once with the k-shifted window
    Dw[:,:]=Dw+weight*w[:,k:k+width]
  return Dw

def Calculate_gamjt():
  """Return gamma_j tilde for j = 0..n in double precision.

  g[j] = (-1)**m * (m/2 - j)**(m-ny)  for j <= m//2,
  g[j] = (j - m/2)**(m-ny)            for j >  m//2,
  with n, m and ny taken from the module globals.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  jjj=linspace(0.,n,n+1)
  m1m=(-1)**m
  # Floor division: the split index must stay an integer on Python 3 too.
  split=m//2+1
  g=ones(n+1)
  g[0:split]=m1m*(m/2.-jjj[0:split])**(m-ny)
  g[split:n+1]=(jjj[split:n+1]-m/2.)**(m-ny)
  return g

def Decimal_gamjt():
  """Return gamma_j tilde for j = 0..n as a list of Decimal values.

  Only ny == 0.5 is supported: the power (m-ny) is then evaluated as the
  square root of the integer power 2*(m-ny), in the current decimal
  context precision.

  Raises ValueError for any other ny.  (Formerly a message was printed
  and None returned, which made the caller fail later with an unrelated
  TypeError.)
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  if ny!=0.5:
    raise ValueError('ERROR: Decimal_gamjt written only for ny=0.5')
  m1m=(-1)**m
  # Floor division: the split index must stay an integer on Python 3 too.
  split=m//2+1
  twice_power=int(2*(m-ny))
  g=[Decimal(1)]*(n+1) # Decimal ones
  for j in range(split):
    g[j]=m1m*(Decimal(str(m/2.-j))**twice_power).sqrt()
  for j in range(split,n+1):
    g[j]=(Decimal(str(j-m/2.))**twice_power).sqrt()
  return g

def Get_beta_ij(ii,jjm):
  """Return the beta_ij matrix, using a file cache in 'PreCalc'.

  If 'PreCalc/beta_ij_<m>_<n>_<ny>.data' exists it is read and reshaped
  to (len(ii), len(jjm)-m); otherwise Calculate_beta_ij is called and its
  result is written to that file.

  NOTE(review): the cache key holds only m, n and ny.  The precision
  switches (p4a, p4b, p0b) also influence Calculate_beta_ij, so a cached
  file from a run with different switches is reused silently.
  """
  import os
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  if not os.path.exists('PreCalc'):
    os.mkdir('PreCalc')
  filen='PreCalc/beta_ij_'+str(m)+'_'+str(n)+'_'+str(ny)+'.data'
  if os.path.exists(filen):
    # if the file exists, read from there
    beta_ij1=fromfile(file=filen)
    print('  beta_ij read from file: %s' % (filen,))
    return beta_ij1.reshape((len(ii),len(jjm)-m))
  # otherwise calculate and save the result
  beta_ij=Calculate_beta_ij(ii,jjm)
  beta_ij.tofile(file=filen)
  print('  beta_ij written to file: %s' % (filen,))
  return beta_ij

def Calculate_beta_ij(ii,jjm): # (54)
  """Compute the matrix beta_ij, shape (len(ii), len(jjm)-m).

  Two constructions are available:
    gg  -- direct evaluation of the four-case formula followed by m
           differences along j (only when the global ``p0b`` is set; with
           ``p4b`` also set only the first 2*m columns are evaluated);
    ggg -- structured fill: the vector h1/mult[0]*Calculate_Dm_gamjt() is
           placed along shifted rows/columns by symmetry, and the
           non-repeating border/corner entries are copied from ``gg``
           (they stay zero when ``p0b`` is not set).

  Which one is returned depends on the global switches:
    p4a and p4b -> ggg (difference to the quad result is printed)
    p4a only    -> the quad result ggf90
    p4b only    -> ggg
    neither     -> gg when p0b is set, otherwise ggg
  The last case used to return the undefined name ``gg`` when p0b was
  not set (NameError).

  ii  -- integer row indices (array).
  jjm -- integer column indices (array).
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  idim=len(ii)
  jdim=len(jjm)
  if p4a:
    # Do the calculations in intel f90 through quadruple precision:
    g1=zeros(idim*(jdim-m))
    quad.quadruple.calculate_bet_ij(idim,jdim,ii,jjm,m,n,ny,g1)
    ggf90=g1.reshape(idim,jdim-m)
  mult=ones(m)
  mult[m-1]=1.-ny
  for k in range(1,m): # 1,2,...,m-1
    mult[m-k-1]=mult[m-k]*(k+1-ny)
  h=1./n
  h1=h**(1.-ny)
  if p0b:
    if p4b:
      ggjdim=2*m
    else:
      ggjdim=jdim
    gg=zeros((idim,ggjdim))
    m1m=(-1.)**m
    # Prepare the multiplier for the row 1 and 4:
    fakt1=zeros(m)
    fakt4=zeros(m)
    fakt1[0]=m1m
    fakt4[0]=1.
    for k in range(1,m): # 1,2,...,m-1
      fakt1[k]=-fakt1[k-1]*k
      fakt4[k]=fakt4[k-1]*k
    # Prefactors of rows 2 and 3; they depend on neither i nor j, so they
    # are computed once here (same division order as before).
    ff_row2=m1m
    ff_row3=1.
    for k in range(1,m+1): # 1,2,...,m
      ff_row2=ff_row2/(k-ny)
      ff_row3=ff_row3/(k-ny)
    for j in range(ggjdim):
      if jjm[j]<0: # 1st row
        for k in range(m): # 0,1,...,m-1
          gg[:,j]=gg[:,j]+(ii+m/2.)**(m-k-ny)/fakt1[k]*jjm[j]**k/mult[k]
      elif jjm[j]>n: # 4th row
        for k in range(m): # 0,1,...,m-1
          gg[:,j]=gg[:,j]+(n-ii-m/2.)**(m-k-ny)/fakt4[k]*(jjm[j]-n)**k/mult[k]
      else:
        for i in range(idim):
          if jjm[j]<ii[i]+m/2.: # 2nd row
            gg[i,j]=ff_row2*(ii[i]-jjm[j]+m/2.)**(m-ny)
          else: # 3rd row
            gg[i,j]=ff_row3*(jjm[j]-ii[i]-m/2.)**(m-ny)
    gg=h1*Calculate_D2_m(gg)
  bet0jt=h1/mult[0]*Calculate_Dm_gamjt()
  ggg=zeros((idim,jdim-m))
  if p0b:
    ggg[:,:m-1]=gg[:,:m-1] # left m_1 columns - need to be calculated!!!
  ggg[::-1,:n-1:-1]=ggg[:,:m-1] #upside-down sym.with right m_1 cols
  #  for comment  :  (Reversing an array can be done by: a[::-1])
  #                  (a[len(a)-1:-1:-1] is wrong!)
  # the placement rule is: ggg[i,j]=,bet0jt[abs(i-j+(m-1-m_0))]
  #     for all i and m-1 <= j <= n+1
  #         and abs(i-j+(m-1-m_0)) <= n-m
  #
  # Let's try to do it by vector operations:
  # place the beta_{0j} to its location.
  #   note that in our indexing, beta_{0j} in the paper corresponds to
  #     beta_{0+m_0,j+m_1}
  ggg[m_0,m-1:n]=bet0jt[0:n-m+1]
  # by symmetry, add also the column down (without the first el.) :
  ggg[m_0+1:idim-m_0,m-1]=ggg[m_0,m:n]
  # row(s) above (with the move to the left and shrinking)
  for i in range(1,m_0+1):
    ggg[m_0-i,m-1:n-i]=bet0jt[i:n-m+1]
  # the corner element(s) not repr (need to be calculated!!!):
  #                        forming at bottom left first
  if p0b:
    ggg[idim-m_0:idim,m-1]=gg[idim-m_0:idim,m-1]
    # the same to up right by symmetry:
    ggg[m_0-1::-1,n-1]=gg[idim-m_0:idim,m-1]
  # now, under row m_0, the same vector is given with a shift to the right
  #   (length) shrinking correspondingly
  for i in range(1,n-m+1):
    ggg[m_0+i,m-1+i:n]=bet0jt[0:n-m+1-i]
    # and by symmetry, add also the column down:
    ggg[m_0+1+i:idim,m-1+i]=ggg[m_0+i:idim-1,m-1+i-1]
    # and by symm the missing part up on first rows:
    if i<m_0:
      ggg[m_0-1-i::-1,n-1-i]=ggg[idim-m_0:idim-i,m-1]
  if p0b and p4b:
    print('max diff between double beta_ij and new one is: %s'
          % (abs(ggg[:,:ggjdim-m]-gg).max(),))
  if p4a and p4b:
    print('max diff between double and quadruple beta_ij is: %s'
          % (abs(ggf90-ggg).max(),))
    return ggg
  elif p4a:
    print('max diff between new  and quadruple beta_ij is: %s'
          % (abs(ggf90-ggg).max(),))
    return ggf90
  elif p4b:
    return ggg
  elif p0b:
    return gg
  else:
    # gg was never built (p0b unset): the structured matrix is the result
    return ggg

def Calculate_gamma_j0(jjm):
  """Return gamma_j0 for every index in ``jjm``.

  g = j**m - (j-n)**m  for j > n,
  g = j**m             for 0 <= j <= n,
  g = 0                for j < 0,
  with n and m taken from the module globals.

  Integral indices are converted to Python integers first, so the powers
  are computed exactly and cannot wrap around the way fixed-width NumPy
  integers do (e.g. 600**7 exceeds the 64-bit range).  This assumes the
  globals n and m are plain Python ints.

  jjm     -- 1-D sequence of indices.
  returns -- float array of len(jjm) values.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  dim=len(jjm)
  g=zeros((dim))
  for i in range(dim):
    j=jjm[i]
    if j==int(j):
      j=int(j) # exact arbitrary-precision arithmetic from here on
    if j>n:
      g[i]=float(j**m-(j-n)**m)
    elif j>=0:
      g[i]=float(j**m)
  return g

def Calculate_tau(alpha_p,beta_ij,a_ij,beta_j0,b_ij): # (49)
  """Combine the weights into the matrix tau.

  Column k of the result is
    sum_{j=k}^{k+2*(m_1-1)} alpha_p[j-k]
        * (beta_ij[:,j]*a_ij[:,k] + beta_j0[k]*b_ij[:,k]),
  with m_1 taken from the module globals.  The column loop runs over
  len(a_ij[:,0]) columns, i.e. a_ij is treated as square.
  With the global switch ``p4t`` set, the ``quad`` result is returned
  instead and the maximal difference is printed.

  returns -- array with the shape of a_ij (or of b_ij when p4t is set).
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  idim=len(b_ij[:,0])
  jdim=len(b_ij[0,:])
  if p4t:
    # Do the calculations in intel f90 through quadruple precision:
    r1=zeros(idim*jdim)
    quad.quadruple.calculate_tau(idim,jdim,beta_j0,b_ij,r1)
    tauf90=r1.reshape(idim,jdim)
  tau=zeros(shape(a_ij))
  dim=len(a_ij[:,0])
  for k in range(dim):
    for j in range(k,k+2*(m_1-1)+1):
      tau[:,k]=tau[:,k]+alpha_p[j-k]*(beta_ij[:,j]*a_ij[:,k]+beta_j0[k]*b_ij[:,k])
  if p4t:
    print('max diff between double and quadruple tau is: %s'
          % (abs(tauf90-tau).max(),))
    return tauf90
  return tau

def L2Norm(x):
  """Return the Euclidean norm: square root of the sum of squares."""
  squares=x*x
  total=sum(squares)
  return sqrt(total)

def WeightedNorm(x,t):
  """Return max_i |(t_i*(1-t_i))**1.5 * x_i|.

  The absolute value is taken before the maximum: without it a vector
  whose largest-magnitude entry is negative yielded a too small (even
  negative) "norm".

  x -- array of values (e.g. an error vector).
  t -- array of nodes, same length as x.
  """
  weights=(t*(1.-t))**1.5
  return abs(weights*x).max()

def FiiPrimNorm(fii_prim,x):
  """Return max_i |fii_prim_i * x_i|.

  The absolute value is taken before the maximum, so that negative
  entries of large magnitude are not ignored.
  """
  return abs(fii_prim*x).max()

def FiiPrimNyNorm(fii_prim,x):
  """Return max_i |fii_prim_i**ny * x_i| (ny is the module global).

  The absolute value is taken before the maximum, so that negative
  entries of large magnitude are not ignored.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  return abs(fii_prim**ny*x).max()

def Calculate_phi_inverse(x,eps,r_0,r_1,c_star):
  """Invert fi by Newton iteration: find tk with calculate_fii(tk) = x.

  All components start at 0.5 and are updated simultaneously with
  tk -= (fi(tk)-x)/fi'(tk).  At most 100 sweeps are made; the loop stops
  earlier once the largest update is below eps*0.01.

  Fixes: calculate_fii takes (t,r_0,r_1) -- the former 4-argument call
  always raised TypeError -- and the stop test now looks at the actual
  update instead of a counter that was merely divided by len(x).

  x       -- array of target values.
  eps     -- accuracy scale for the stop test.
  c_star  -- normalising constant passed on to calculate_fii_prim.
  returns -- array of len(x) parameter values.
  """
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  dim=len(x)
  tk=ones(dim)
  tk=tk/2.
  if dim==0:
    return tk
  for k in range(100):
    fi=calculate_fii(tk,r_0,r_1)
    fip=calculate_fii_prim(tk,r_0,r_1,c_star)
    step=(fi-x)/fip
    tk=tk-step
    if abs(step).max()<eps*0.01:
      break
  return tk

def doplot(y,xlabel,ylabel): # for plotting 1D functions
  """Plot the 1-D data ``y`` with gnuplot and wait for Enter.

  ``ylabel`` is used both as the y-axis label and as the plot title.
  """
  import Gnuplot, Gnuplot.funcutils
  g = Gnuplot.Gnuplot(debug=0)
  g.clear()
  g.ylabel(ylabel)
  g.xlabel(xlabel)
  g.title(ylabel)
  g('set data style linespoints')
  g.plot(y)
  # raw_input exists only on Python 2; input is its Python 3 equivalent.
  try:
    wait_for_enter=raw_input
  except NameError:
    wait_for_enter=input
  wait_for_enter("click enter <--' ")

def twoplot1(y1,y2,xlabel,ylabel): # for plotting 1D functions
  """Plot ``y1`` and ``y2`` over each other (multiplot) and wait for Enter.

  ``y1`` is drawn with the current default style, ``y2`` with lines.
  ``ylabel`` is used both as the y-axis label and as the plot title.
  """
  import Gnuplot, Gnuplot.funcutils
  g = Gnuplot.Gnuplot(debug=0)
  g.clear()
  g.ylabel(ylabel)
  g.xlabel(xlabel)
  g.title(ylabel)
  g('set multiplot')
  g.plot(y1)
  g('set data style line')
  g.plot(y2)
  # raw_input exists only on Python 2; input is its Python 3 equivalent.
  try:
    wait_for_enter=raw_input
  except NameError:
    wait_for_enter=input
  wait_for_enter("click enter <--' ")

def twoplot(x,y1,y2,xlabel,ylabel,title,text1,text2): # for plotting 1D functions
  """Plot y1 (points) and y2 (lines) against x in one graph; wait for Enter.

  text1 / text2 are the legend entries of the two data sets.
  """
  import Gnuplot, Gnuplot.funcutils
  g = Gnuplot.Gnuplot(debug=0)
  g.clear()
  g.ylabel(ylabel)
  g.xlabel(xlabel)
  g.title(title)
  g('set data style line')
  # `with` is a reserved word since Python 2.6, so the option is passed
  # through a dict instead of as a literal keyword argument.
  # NOTE(review): newer Gnuplot.py releases may spell this option
  # `with_` -- confirm against the installed version.
  d1 = Gnuplot.Data(x, y1,
                   title=text1,
                   **{'with': 'points 3 1'})
  d2 = Gnuplot.Data(x, y2,
                   title=text2,
                   **{'with': 'lines 1 1'})
  g.plot(d1,d2)
  # raw_input exists only on Python 2; input is its Python 3 equivalent.
  try:
    wait_for_enter=raw_input
  except NameError:
    wait_for_enter=input
  wait_for_enter("click enter <--' ")

def eigenplot(eigs,xlabel,ylabel,title,text): # for plotting eigenvalues
  """Scatter-plot the eigenvalues (imaginary vs. real part); wait for Enter.

  eigs -- array of (complex) eigenvalues.
  text -- legend entry of the data set.
  """
  import Gnuplot, Gnuplot.funcutils
  g = Gnuplot.Gnuplot(debug=0)
  g.clear()
  g.ylabel(ylabel)
  g.xlabel(xlabel)
  g.title(title)
  g('set data style line')
  real_parts=real(eigs)
  imag_parts=imag(eigs)
  # `with` is a reserved word since Python 2.6, so the option is passed
  # through a dict instead of as a literal keyword argument.
  # NOTE(review): newer Gnuplot.py releases may spell this option
  # `with_` -- confirm against the installed version.
  d = Gnuplot.Data(real_parts,imag_parts,
                   title=text,
                   **{'with': 'points 3 1'})
  g.plot(d)
  # raw_input exists only on Python 2; input is its Python 3 equivalent.
  try:
    wait_for_enter=raw_input
  except NameError:
    wait_for_enter=input
  wait_for_enter("click enter <--' ")

def pstwoplot(x,y1,y2,xlabel,ylabel,title,text1,text2): # for plotting 1D functions
  """Plot y1 (points) and y2 (lines) against x and save 'twoplot.ps'.

  Same graph as twoplot, but instead of waiting for Enter a colour
  PostScript hard copy is written to the file 'twoplot.ps'.
  """
  import Gnuplot, Gnuplot.funcutils
  g = Gnuplot.Gnuplot(debug=0)
  g.clear()
  g.ylabel(ylabel)
  g.xlabel(xlabel)
  g.title(title)
  g('set data style line')
  # `with` is a reserved word since Python 2.6, so the option is passed
  # through a dict instead of as a literal keyword argument.
  # NOTE(review): newer Gnuplot.py releases may spell this option
  # `with_` -- confirm against the installed version.
  d1 = Gnuplot.Data(x, y1,
                   title=text1,
                   **{'with': 'points 3 1'})
  d2 = Gnuplot.Data(x, y2,
                   title=text2,
                   **{'with': 'lines 1 1'})
  g.plot(d1,d2)
  g.hardcopy('twoplot.ps', enhanced=1, color=1)

def plot3d(f,label): # for plotting a 2D array as a surface
  """Draw the matrix ``f`` as a surface over its index grid; wait for Enter.

  f     -- 2-D array; rows are plotted along x, columns along y.
  label -- plot title.
  """
  import Gnuplot, Gnuplot.funcutils
  g = Gnuplot.Gnuplot(debug=0)
  g.clear()
  x = arange(len(f[:,0]))
  y = arange(len(f[0,:]))
  g('set parametric')
  g('set data style lines')
  g('set hidden')
  g('set contour base')
  g.title(label)
  g.xlabel('x')
  g.ylabel('y')
  g.splot(Gnuplot.GridData(f,x,y, binary=0))
  # raw_input exists only on Python 2; input is its Python 3 equivalent.
  try:
    wait_for_enter=raw_input
  except NameError:
    wait_for_enter=input
  wait_for_enter("click enter <--' ")

#Main part starts here: ================================
def calculations(n_in,nn,m_in,m_0_in,m_1_in,rr):
  """Run one complete solve and return error norms and timings.

  Stores n, m, m_0, m_1 in the module globals, builds the grid, the
  smoothing transformation fii (with r_0 = r_1 = rr), the matrices
  beta_ij, PHI, Acurly, Bcurly and tau, solves
  (50*I - tau) x = f(fii) with GMRES, writes the solution to
  'tmp/sol_n<n>_m<m>_r<rr>' in Matrix Market format and compares it with
  known_solution(fii).

  n_in    -- number of grid intervals.
  nn      -- dimension of the linear system.
  m_in, m_0_in, m_1_in -- stored in the globals m, m_0, m_1.
  rr      -- common value of the exponents r_0 and r_1.
  returns -- tuple (L2 error, max error, sum of |errors|, WeightedNorm,
             total time, solve time, info[2] of gmres, FiiPrimNorm,
             FiiPrimNyNorm).
  """
  import os
  global n,m,m_0,m_1,ny,p4a,p4t,p4f,p4h,p4r,p4b,p0b
  n=n_in
  m=m_in
  m_0=m_0_in
  m_1=m_1_in
  totime=time()
  r_0=rr # or 9
  r_1=rr # or 9
  c_star=calculate_c_star(r_0,r_1)
  h=1./n
  tn=n-m_1+m_0+1
  t=linspace((-m_0+m/2.)*h,(n-m_1+m/2.)*h,tn)
  s=linspace((-m_0+m/2.)*h,(n-m_1+m/2.)*h,tn)
  ii=arange(-m_0,n-m_1+1)
  jjm=arange(-m+1,n+m)
  gamma_j0=Calculate_gamma_j0(jjm)
  beta_j0=h*Calculate_D1_m(gamma_j0)/factorial(m) # (52) # can be replaced by h
                                                       #   in (0,1)...
                # beta_j0 must be positive!!
  print('starting calculation of beta_ij...')
  btim=time()
  beta_ij=Get_beta_ij(ii,jjm)
  print('time for beta_ij: %s' % (time()-btim,))
  fii=calculate_fii(s,r_0,r_1)
  # clip rounding overshoot so that (1-fii) never becomes negative
  wfii=where(fii>1.)
  if size(wfii)>0:
    print('overshoot where(fii>1.): %s' % (wfii,))
    fii[wfii]=1.
  fii_prim=calculate_fii_prim(s,r_0,r_1,c_star)
  PHI=Get_PHI(t,s,r_0,r_1,fii_prim)
  a_ij=calculate_Acurly(t,s,fii,PHI,fii_prim)
  b_ij=calculate_Bcurly(t,s,fii,fii_prim)
  g_i=f(fii)
  alpha_p=calculate_alpha_p()
  tau_ij=Calculate_tau(alpha_p,beta_ij,a_ij,beta_j0,b_ij)
  #A=identity(nn)-tau_ij # (the case with known solution)
  A=50.0*identity(nn)-tau_ij # for testing with ny>0.6 we use multiplier to
                             #   avoid very bad systems with lots of neg.
                             #   eigenvalues...
  bb=g_i
  xx=zeros(nn)
  tim=time();info=linalg.iterative.gmres(A,bb,nn,xx,1e-16,None,'d')
  tim=time()-tim
  totime=time()-totime
  #### To see the eigenvalues of A:
  #eigs=linalg.eigvals(A)
  #eigenplot(eigs,'i','eig(A)','Matrix A eigenvalues','eigenvalues ')
  ####
  if info[1]!=0:
    print('###### exit code: %s' % (info[1],))
  nits=info[2]
  # io.mmwrite needs a 2-D array, so the solution becomes a column matrix
  xxmat=zeros((len(xx),1))
  xxmat[:,0]=xx
  if not os.path.exists('tmp'):
    os.mkdir('tmp')
  io.mmwrite("tmp/sol_n"+str(n)+"_m"+str(m)+"_r"+str(rr),xxmat)
  exact_sol_fii=known_solution(fii)
  l2n_fii=L2Norm(exact_sol_fii-xx)
  maxn_fii=max(abs(exact_sol_fii-xx))
  diffsums_fii=sum(abs(exact_sol_fii-xx))
  wn_fii=WeightedNorm(exact_sol_fii-xx,t)
  fpn=FiiPrimNorm(fii_prim,exact_sol_fii-xx)
  fpnn=FiiPrimNyNorm(fii_prim,exact_sol_fii-xx)
  ##twoplot(t,exact_sol_fii,xx,'t_i','Functions',\
  ##        'n='+str(n)+'r_0=r_1='+str(rr),\
  ##        'exact u(fii):','calculated v_n:')
  # Paragraph 4.4 (mapping the solution back with Calculate_phi_inverse)
  # is currently switched off:
  #xxx=Calculate_phi_inverse(xx,maxn_fii,r_0,r_1,c_star)
  #exact_sol=known_solution(s)
  #l2n=L2Norm(exact_sol-xxx)
  #maxn=max(abs(exact_sol-xxx))
  #diffsums=sum(abs(exact_sol-xxx))
  #return (l2n_fii,maxn_fii,diffsums_fii,l2n,maxn,diffsums,totime,tim,nits)
  return (l2n_fii,maxn_fii,diffsums_fii,wn_fii,totime,tim,nits,fpn,fpnn)

if __name__ == "__main__":
  # Script entry point: set the parameters and switches as module globals
  # and run one calculation.  The table loop further down is currently
  # unreachable because of the exit(0) right after the first run.
  #Parameters: ------------------------------------------
  ny=0.9
  lambda_0=0.
  lambda_1=0.
  n=8
  m=7
  # Parity test and halving use % and // so that they give the same
  # integers on Python 2 and Python 3.
  if m%2==0: # (10),(19):
    m_0=(m-2)//2 # m is even
    m_1=m//2+1
  else:
    m_0=(m-1)//2 # m is odd
    m_1=(m+1)//2
  #------------------------------------------------------
  p4a=0 # calculation of (fii,fii_prim,Acurly,beta_ij) in quadruple?
  p4t=0 # calculation of tau in quadruple?
  p4f=0 # calculation of fii,fii_prim in quadruple?
  p4h=0 # calculation of rHs in quadruple?
  p4r=0 # calculation of exact result in quadruple?
  p4b=1 # calculation of Delta_m_gamma_j_tilde in quadruple?
  #p4b=64 # >1                                   in p4b Decimal digits?
  p0b=1 # 0 -- set all the not-repeating beta_ij to zero!
  #------------------------------------------------------
  n=256
  rr=10
  if m%2==0: # (10),(19):
    nn=n-1 # after (65) -- the dimension of the lin. system
  else:
    nn=n
  ans=calculations(n,nn,m,m_0,m_1,rr)
  exit(0)
  #------------------------------------------------------
  # Everything below is skipped while the exit(0) above is in place.
  #------------------------------------------------------
  tb1s=4
  tb1e=9
  tb2s=2
  tb2e=10
  table1=zeros((tb1e-tb1s,tb2e-tb2s))
  table2=zeros((tb1e-tb1s,tb2e-tb2s))
  for i in range(tb1s,tb1e):
    n=n*2
    if m%2==0: # (10),(19):
      nn=n-1 # after (65) -- the dimension of the lin. system
    else:
      nn=n
    for rr in range(tb2s,tb2e):
      print(' '.join(str(v) for v in (
        '===========> n =',n,'====> m (_0 _1)=',m,'(',m_0,m_1,')',
        ' ====> rr =',rr,' <===============')))
      ans=calculations(n,nn,m,m_0,m_1,rr)
      exit(0)
      print('--->                  Max diff. from known sol: %8.2e' % ans[1])
      #print('---> WeightedNorm of diff from known sol(fii): %9.3e' % ans[3])
      #print('---> FiiPrimNorm of diff from known sol(fii): %9.3e' % ans[7])
      print('---> FiiPrimNyNorm of diff from known sol(fii): %8.2e' % ans[8])
      print('--->              Solve time (#it): %10.4e (%d)' % (ans[5],ans[6]))
      print('--->                    Total time: %10.4e' % ans[4])
      #print('---> L2Norm of diff from known sol: %9.3e' % ans[0])
      #print('        (Sum of diffs): %9.3e' % ans[2])

      table1[i-tb1s,rr-tb2s]=ans[1]
      table2[i-tb1s,rr-tb2s]=ans[8]
  plot3d(table1,'Max diff')
  #io.mmwrite("table1",table1)
  #io.mmwrite("table2",table2)