#Copy of the romberg quadrature class from the PS3 solution script
from polint import polint
from ClimateUtilities import *
class BetterTrap:
    """Trapezoidal-rule integrator that can be refined by halving its step.

    f        -- integrand, called with a single numeric argument
    interval -- [a, b]: lower limit interval[0], upper limit interval[1]
    nstart   -- initial number of subintervals (must be >= 1; 0 divides by zero)

    After construction self.integral holds the trapezoidal estimate using
    self.n subintervals.  Each call to refine() doubles self.n and updates
    self.integral, evaluating f only at the newly added midpoints.
    """
    def __init__(self,f,interval,nstart):
        self.f = f
        self.n = nstart
        self.interval = interval
        self.integral = self.dumbTrap(nstart)

    def dumbTrap(self,n):
        """Return the trapezoidal-rule estimate with n equal subintervals,
        computed from scratch (n+1 evaluations of f)."""
        a = self.interval[0]
        b = self.interval[1]
        #float() guards against integer division when the endpoints and n
        #are all ints under Python 2 (interval [0,1] would otherwise give dx = 0)
        dx = (b-a)/float(n)
        endpoints = (self.f(a)+self.f(b))/2.
        interior = sum(self.f(a+i*dx) for i in range(1,n))
        return dx*(endpoints+interior)

    def refine(self):
        """Halve the step size: double self.n and update self.integral.

        Only the n new midpoints are evaluated.  The old trapezoidal sum was
        weighted by the old dx; the new dx is half as large, so the old sum
        contributes .5*self.integral and the midpoints are weighted by dx/2.
        """
        a = self.interval[0]
        b = self.interval[1]
        #n counts subintervals, not endpoints, so there is exactly one
        #midpoint per subinterval: i runs over range(self.n).
        #float() avoids Python 2 integer division (see dumbTrap).
        dx = (b-a)/float(self.n)
        midpoints = sum(self.f(a+(i+.5)*dx) for i in range(self.n))
        self.integral = .5*self.integral + midpoints*(dx/2.)
        #Update the number of intervals
        self.n = 2*self.n
class romberg:
    """Romberg quadrature layered on a BetterTrap trapezoidal integrator.

    Every refinement halves the trapezoidal step and records the new estimate.
    Calling the object extrapolates the recorded estimates to zero step size
    with polint.

    Design note from the original author: a call-with-refinement-count
    interface may not be ideal.  An alternative would accept a target accuracy
    (and optionally a maximum number of refinements) and refine until the
    criterion is met.  The intermediate results would remain available in
    self.integralList and self.nList for exploration.
    """
    def __init__(self,f,interval,nstart):
        #Trapezoidal integrator whose successive estimates get extrapolated
        self.trap = BetterTrap(f,interval,nstart)
        #Parallel histories: subinterval counts and the matching estimates
        self.nList, self.integralList = [nstart], [self.trap.integral]

    def refine(self):
        """Halve the trapezoidal step once and record the resulting estimate."""
        trap = self.trap
        trap.refine()
        self.integralList.append(trap.integral)
        self.nList.append(trap.n)

    def __call__(self,nRefine = 0):
        """Refine nRefine more times, then return the extrapolated integral.

        The return value is whatever polint returns for the abscissae 1/n**2
        and the recorded estimates, evaluated at 0.  Presumably that is the
        polynomial-extrapolated value; confirm against polint.
        """
        for _ in range(nRefine):
            self.refine()
        #1/n**2 is proportional to the squared step; 0 is the zero-step limit
        stepSquared = [1./(n*n) for n in self.nList]
        return polint(stepSquared,self.integralList,0.)


#Problem 1------------
#
#Define the integrand
import math
def f(x):
    """Integrand for Problem 1: exp(cos(x))."""
    exponent = math.cos(x)
    return math.exp(exponent)

#Now define P(t): the integral of f from 0 to t, multiplied by exp(-cos(t))
def P(t):
    """Return exp(-cos(t)) times the integral of f from 0 to t (Problem 1).

    The integral is computed with the romberg class, starting from 10
    trapezoidal subintervals and refining 5 times.  The result's type is
    whatever romberg's __call__ (i.e. polint) returns, scaled by a float.
    """
    integrator = romberg(f,[0.,t],10)
    #Five refinements is probably more than needed; an accuracy check could
    #decide the count instead
    integral = integrator(5)
    return integral*math.exp(-math.cos(t))

#Put it in a curve and plot it
#100 sample times from 0 to 9.9*pi, spaced pi/10 apart
t = [k*math.pi/10. for k in range(100)]
c = Curve()
#Time column first, then the population evaluated at each time
c.addCurve(t,'t','Time')
c.addCurve([P(tau) for tau in t],'P','Population')
plot(c)

#Problem 2:
#A makes v(0) = -1/2 + A*sqrt(3)/2 = 1; with x(0) = 1 this gives the
#initial conditions x(0) = 1, dx/dt(0) = 1.
A = math.sqrt(3.)
def x(t):
    """Position for Problem 2:
    x(t) = exp(-t/2)*(cos(w*t) + A*sin(w*t)),  with w = sqrt(3)/2.
    """
    return math.exp(-t/2.)*(math.cos(math.sqrt(3.)*t/2.) +
        A*math.sin(math.sqrt(3.)*t/2.) )
def v(t):
    """Velocity dx/dt for x(t) above.

    With w = sqrt(3)/2, differentiating exp(-t/2)*(cos(w*t) + A*sin(w*t))
    gives exp(-t/2)*(a*cos(w*t) + b*sin(w*t)), where
        a = -1/2 + A*w   (from -1/2*cos and d/dt of A*sin)
        b = -A/2 - w     (from -1/2*A*sin and d/dt of cos)
    """
    a = -1./2. + A*math.sqrt(3.)/2.
    #The -A/2 term comes from the exp(-t/2) factor multiplying A*sin(w*t)
    b = -A/2. - math.sqrt(3.)/2.
    return math.exp(-t/2.)*(a*math.cos(math.sqrt(3.)*t/2.) +
        b*math.sin(math.sqrt(3.)*t/2.) )

#Plot v against x.  Curve presumably uses the first column added as the
#horizontal axis (Problem 1 adds t first), so this is a phase-plane plot.
c = Curve()
#range() requires an int; range(80.) raises TypeError
t = [i*math.pi/10. for i in range(80)]
c.addCurve([x(tt) for tt in t],'x','Position')
c.addCurve([v(tt) for tt in t],'v','Velocity')
plot(c)

def x1(t):
    """Repeated-root (critically damped) solution: (1 + 2t)*exp(-t); x1(0) = 1."""
    decay = math.exp(-t)
    return (1.+2.*t)*decay
def v1(t):
    """Velocity dx1/dt = (1 - 2t)*exp(-t); v1(0) = 1."""
    decay = math.exp(-t)
    return (1. - 2.*t)*decay
#Same phase-plane plot (v1 against x1) for the repeated-root solution.
#range() requires an int; range(80.) raises TypeError
t = [i*math.pi/10. for i in range(80)]
c = Curve()
c.addCurve([x1(tt) for tt in t],'x','Position')
c.addCurve([v1(tt) for tt in t],'v','Velocity')
plot(c)

