plotutils.py

Sanchez David, 11/05/2014 09:45 AM

Download (16.5 KB)

 
1
# ==========================================================================
2
# This module provides classes for plotting CTA analysis results: spectra
# (model, butterfly, data points and residuals), light curves and sky maps.
4
#
5
# Copyright (C) 2011-2014 David Sanchez, Michael Mayer, Rolf Buehler, Juergen Knoedlseder
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU General Public License as published by
9
# the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
#
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU General Public License for more details.
16
#
17
# You should have received a copy of the GNU General Public License
18
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
#
20
# ==========================================================================
21
import ctools as ct
22
import gammalib as gl
23
from math import log10,pow,sqrt
24
import os
25

    
26
try:
27
        import matplotlib.pyplot as plt
28
        import matplotlib.gridspec as gridspec
29
        has_matplotlib = True
30
except:
31
        has_matplotlib = False
32

    
33
try:
34
        import aplpy
35
        has_aplpy = True
36
except:
37
        has_aplpy = False
38

    
39
class options(object):
    """Default plotting style shared by the plotter classes.

    Holds the matplotlib keyword values used when drawing (marker format,
    colours, line widths, cap and marker sizes), the TS threshold that
    separates flux points from upper limits, and the relative length of the
    upper-limit error bars.
    """

    def __init__(self):
        # Marker appearance; the marker colour follows the base colour.
        self.fmt = 'o'
        self.color = BLUE
        self.markercolor = self.color
        self.markersize = 5.0
        self.label = "_nolegend_"

        # Reference lines and error bars.
        self.linecolor = 'black'
        self.linewidth = 2.0
        self.linestyle = '--'
        self.elinewidth = 2.0
        self.capsize = 5.0

        # Points with TS >= tslim are drawn as points, the others as limits.
        self.tslim = 4.0
        # Lower error-bar length of an upper limit, as a fraction of the limit.
        self.limlength = 0.4
53

    
54
import analysisutils 
55

    
56
class SpectralBase(object):
    """Mixin holding the SED-conversion settings shared by spectral plotters."""

    def __init__(self, sed=True, sed_factor=1.):
        """
        Parameters:
            sed        -- if True, _convertSED multiplies values by E^2 * sed_factor
            sed_factor -- extra multiplicative factor applied in SED mode
        """
        self.m_sed_factor = sed_factor
        self.m_sed = sed

    def _convertSED(self, tab, energy):
        """Convert a table of dN/dE values into SED values (E^2 dN/dE).

        When SED mode is enabled, tab[i] is multiplied in place by
        energy[i]**2 * self.m_sed_factor; otherwise tab is left untouched.

        Returns the same list `tab`, or None after reporting through
        self.error() when `tab` and `energy` differ in length. self.error is
        not defined here; it must come from a co-base class
        (e.g. analysisutils.base).
        """
        if len(tab) != len(energy):
            self.error("input table and energy table does not match")
            return None
        if self.m_sed:
            # enumerate (instead of xrange) works under Python 2 and 3.
            for i, ene in enumerate(energy):
                tab[i] *= ene ** 2 * self.m_sed_factor
        return tab
71

    
72
# Colour palette (hex codes or matplotlib colour names); BLUE is the default
# colour of the `options` plotting style.
BLUE = "#337EF5"
GREY = "grey"
GREEN = "#21851D"
RED = "#BD0B0B"
YELLOW = "#D4CF44"
ORANGE = "#D19534"
78

    
79

    
80
class residuals(analysisutils.base, options, SpectralBase):
    """Plot fractional residuals (data - model) / model of spectral data points.

    Only points with TS >= self.tslim are used; the others are upper limits
    and have no residual.
    """

    def __init__(self, spectrum, datapoint, sed, sed_factor, eunit="TeV"):
        """
        Parameters:
            spectrum   -- gammalib spectral model, evaluated via eval(GEnergy, GTime)
            datapoint  -- data points with "ener", "dnde" and "TS" entries
            sed        -- SED flag, stored through SpectralBase
            sed_factor -- SED factor, stored through SpectralBase
            eunit      -- reference energy unit name, stored as self.refUnit
        """
        super(residuals, self).__init__()
        options.__init__(self)
        SpectralBase.__init__(self, sed, sed_factor)

        # load the gammalib spectrum object
        self.spectrum = spectrum
        self.SetDataPoint(datapoint)

        self.xerrors = True
        self.refUnit = eunit     # reference unit for conversion

    def SetDataPoint(self, datapoint):
        """Store the data points and cache their number in self.Npt."""
        self.m_point_info = datapoint
        self.Npt = len(self.m_point_info["dnde"]["value"])

    def _plt_residuals(self):
        """Draw (data - model) / model for every significant point, plus a zero line.

        The residual is a dimensionless ratio, so data and model are compared
        in the same representation and no SED conversion is applied.
        NOTE(review): assumes the data "dnde" values are dN/dE in the same
        units as spectrum.eval() returns -- TODO confirm.
        """
        ener = self.m_point_info["ener"]
        dnde = self.m_point_info["dnde"]
        ts = self.m_point_info["TS"]

        # Only the plotted (significant) points are collected, so that the x
        # and y arrays handed to errorbar always have the same length.
        x, ed_x, eu_x = [], [], []
        y, ed_y, eu_y = [], [], []
        for i in range(self.Npt):
            if ts[i] < self.tslim:
                continue  # upper limit: no residual
            e = ener["value"][i]
            model = self.spectrum.eval(gl.GEnergy(e, ener["unit"]), gl.GTime(0))
            if model == 0:
                self.warning("model flux is zero at E = " + str(e) + ", residual not plotted")
                continue
            x.append(e)
            ed_x.append(ener["ed_value"][i])
            eu_x.append(ener["eu_value"][i])
            y.append((dnde["value"][i] - model) / model)
            ed_y.append(dnde["ed_value"][i] / model)
            eu_y.append(dnde["eu_value"][i] / model)

        if self.xerrors:
            plt.errorbar(x, y, xerr=[ed_x, eu_x], yerr=[ed_y, eu_y],
                         fmt=self.fmt, color=self.markercolor,
                         elinewidth=self.elinewidth)
        else:
            plt.errorbar(x, y, yerr=[ed_y, eu_y],
                         fmt=self.fmt, color=self.markercolor,
                         elinewidth=self.elinewidth)

        plt.axhline(0.0, color=self.linecolor, lw=self.linewidth, ls=self.linestyle)
124

    
125

    
126
class ulimgraph(analysisutils.base, options, SpectralBase):
    """Plot spectral data points and upper limits.

    Points with TS >= self.tslim are drawn as flux points with error bars;
    the others are drawn as upper-limit bars from their "ulim_dnde" value.
    """

    # Number of MeV in one unit of each supported energy unit.
    _MEV_PER_UNIT = {"MeV": 1., "GeV": 1e3, "TeV": 1e6}

    def __init__(self, datapoint, sed, sed_factor, eunit='TeV'):
        """
        Parameters:
            datapoint  -- data points with "ener", "dnde", "TS" and "ulim_dnde" entries
            sed        -- SED flag, stored through SpectralBase
            sed_factor -- SED factor, stored through SpectralBase
            eunit      -- energy unit the data-point energies are converted to
        """
        super(ulimgraph, self).__init__()
        options.__init__(self)
        SpectralBase.__init__(self, sed, sed_factor)

        self.xerrors = True
        self.nlims = -1          # max number of upper limits drawn; -1 = all
        self.refUnit = eunit
        self.SetDataPoint(datapoint)

    def SetDataPoint(self, datapoint):
        """Store the data points and cache their number in self.Npt."""
        self.m_point_info = datapoint
        self.Npt = len(self.m_point_info["dnde"]["value"])

    def _convert_energies(self):
        """Rescale the data-point energies in place so they are in self.refUnit.

        The stored "unit" entry is taken as the source unit; a missing or
        unsupported unit is assumed to be MeV (TODO: define the I/O format).
        The unit entry is then set to self.refUnit, so repeated calls leave
        the values unchanged.
        """
        ener = self.m_point_info["ener"]
        try:
            current = ener["unit"]
        except KeyError:
            current = None
        if current == self.refUnit:
            return
        factor = self._MEV_PER_UNIT.get(current, 1.) / self._MEV_PER_UNIT[self.refUnit]
        for key in ("value", "ed_value", "eu_value"):
            values = ener[key]
            for i in range(self.Npt):
                values[i] *= factor
        ener["unit"] = self.refUnit

    def _plt_points(self):
        """Draw all data points with their error bars.

        Side effect: the "dnde" errors of points below the TS threshold are
        set to zero in the data-point container.
        """
        self._convert_energies()
        ener = self.m_point_info["ener"]
        dnde = self.m_point_info["dnde"]
        ts = self.m_point_info["TS"]

        for i in range(self.Npt):
            if ts[i] >= self.tslim:
                if dnde["ed_value"][i] > dnde["value"][i]:
                    self.warning("\tflux error: " + str(dnde["ed_value"][i]) +
                                 " is larger than flux value " + str(dnde["value"][i]))
                    self.warning("\tTS value is however TS=" + str(ts[i]) + ", weird!")
            else:
                # NOTE(review): the point is still drawn below (with zero
                # error) in addition to its upper limit -- confirm intended.
                dnde["ed_value"][i] = 0
                dnde["eu_value"][i] = 0

        kwargs = dict(yerr=[dnde["ed_value"], dnde["eu_value"]], fmt=self.fmt,
                      color=self.markercolor, elinewidth=self.elinewidth)
        if self.xerrors:
            kwargs["xerr"] = [ener["ed_value"], ener["eu_value"]]
        plt.errorbar(ener["value"], dnde["value"], **kwargs)

    def _plt_limits(self):
        """Draw an upper-limit bar for each point with TS below the threshold.

        At most self.nlims limits are drawn (all of them when nlims == -1).
        """
        self._convert_energies()
        ener_tab = self.m_point_info["ener"]["value"]
        ts = self.m_point_info["TS"]
        ulims = self.m_point_info["ulim_dnde"]

        ener, flux, ed_flux, eu_flux = [], [], [], []
        for i in range(self.Npt):
            if ts[i] >= self.tslim:
                continue
            if self.nlims != -1 and len(flux) >= self.nlims:
                break
            self.info("Plot an upper limit for E = " + str(ener_tab[i]) +
                      " with a TS = " + str(ts[i]))
            ener.append(ener_tab[i])
            flux.append(ulims[i])
            ed_flux.append(self.limlength * ulims[i])
            eu_flux.append(0.0)

        if not flux:
            return

        # markersize=0 hides the marker so only the limit bar is shown (the
        # alias `ms` must not be passed as well: matplotlib rejects both).
        # NOTE(review): the arrow direction of lolims/uplims differs between
        # matplotlib versions -- verify against the installed version.
        plt.errorbar(ener, flux, yerr=[ed_flux, eu_flux], fmt=self.fmt,
                     markersize=0.0, elinewidth=self.elinewidth, lolims=True,
                     capsize=self.capsize, label='_nolegend_',
                     mfc=self.color, ecolor=self.color)
190
            
191

    
192
class LightCurvePlotter(analysisutils.base, options):
    """ Class to plot the Light curve"""

    def __init__(self, srcname, datapoint):
        """
        Parameters:
            srcname   -- source name, stored as self.m_name
            datapoint -- data points in a Results Storage container with
                         "time" (tmin_value, tmax_value, unit) and "Iflux"
                         (value, ed_value, eu_value, unit) entries
        """
        super(LightCurvePlotter, self).__init__()
        options.__init__(self)
        self.m_name = srcname
        self.m_data = datapoint  # data point in a Results Storage class

    def draw(self):
        """Plot the flux of each time bin at the bin centre, with half-width x errors."""
        if not has_matplotlib:
            self.warning("matplotlib module not found, cannot draw")
            return

        times = self.m_data["time"]
        iflux = self.m_data["Iflux"]
        nbins = len(times["tmin_value"])

        # Bin centre and half-width; divide by 2.0 so integer times are not
        # truncated under Python 2.
        time = [(times["tmin_value"][i] + times["tmax_value"][i]) / 2.0 for i in range(nbins)]
        dtime = [(times["tmax_value"][i] - times["tmin_value"][i]) / 2.0 for i in range(nbins)]
        flux = [iflux["value"][i] for i in range(nbins)]
        dflux_ed = [iflux["ed_value"][i] for i in range(nbins)]
        dflux_eu = [iflux["eu_value"][i] for i in range(nbins)]

        plt.figure("results", figsize=(12, 7), edgecolor="w")
        plt.xlabel("Time [" + times["unit"] + "]", fontsize=15)
        # NOTE(review): the label reads "E^2 dN/dE" although the plotted
        # quantity is "Iflux" -- confirm the intended label.
        plt.ylabel("E$^{2}$ dN/dE [" + iflux["unit"] + "]", fontsize=15)

        plt.errorbar(time, flux, xerr=[dtime, dtime], yerr=[dflux_ed, dflux_eu],
                     fmt=self.fmt, color=self.markercolor, elinewidth=self.elinewidth)
        plt.show()
230

    
231
class SpectrumPlotter(analysisutils.base, SpectralBase):
    """ Class to compute and plot the spectra (i.e. best fit model and butterfly).

    If data points are provided, the only change applied to them is an energy
    unit conversion to match the butterfly one (i.e. no dN/dE to SED
    conversion). It is currently assumed that the unit is MeV for the energy
    of the data points; this will change once the I/O container format is
    defined.
    """

    # Number of MeV in one unit of each supported energy unit.
    _MEV_PER_UNIT = {"MeV": 1, "GeV": 1e3, "TeV": 1e6}

    def __init__(self, srcname, like, datapoint, emin=0.1, emax=100,
                 energy="TeV", SED=True, npt=50):
        """
        Parameters:
            srcname   -- name of the source model in like.obs().models()
            like      -- fit object; like.obs() gives the observations/models
            datapoint -- data points, forwarded to ulimgraph and residuals
            emin/emax -- energy range of the model curve, in units of `energy`
            energy    -- energy unit ("MeV", "GeV" or "TeV"; anything else
                         falls back to "TeV")
            SED       -- plot E^2 dN/dE in erg/cm2/s if True, dN/dE otherwise
            npt       -- number of log-spaced energies of the model curve
        """
        super(SpectrumPlotter, self).__init__()

        self.m_ebound = [emin, emax]
        self.m_name = srcname
        self.m_like = like
        self.m_spectral = like.obs().models()[srcname].spectral()
        self.m_energy = []
        self.m_flux = []
        self.m_error = []
        self.m_but = []
        self.m_enebut = []
        self.m_npt = npt
        self.m_covar = gl.GMatrix(1, 1)
        self.m_has_covar = False   # set by covar() once m_covar is filled
        self.covar()

        # Plot style: SED (y unit erg/cm2/s) or differential flux.
        self.m_sed = SED
        self.eunit = energy
        self.m_sed_factor = 1.
        self._validateUnits()
        SpectralBase.__init__(self, self.m_sed, self.m_sed_factor)

        # Data-point energies are assumed to be in MeV (TODO). The validated
        # unit is forwarded so data points and model share one energy axis.
        self.points = ulimgraph(datapoint, self.m_sed, gl.MeV2erg, self.eunit)
        self.residuals = residuals(self.m_spectral, datapoint, self.m_sed,
                                   gl.MeV2erg, self.eunit)

    def _validateUnits(self):
        """Fall back to TeV for unsupported energy units and set the SED factor."""
        if self.eunit not in self._MEV_PER_UNIT:
            self.warning("Change energy unit to TeV")
            self.eunit = "TeV"

        if self.m_sed:
            # Energies are in self.eunit; the factor converts E^2 to MeV^2 and
            # MeV to erg (presumably the model dN/dE is per MeV -- confirm).
            self.m_sed_factor = gl.MeV2erg * self._MEV_PER_UNIT[self.eunit] ** 2

    def covar(self):
        """Extract the covariance of the source's free spectral parameters.

        The full covariance is the inverse of the fit curvature matrix. On
        success the sub-matrix is stored in self.m_covar and
        self.m_has_covar is set to True; otherwise a warning is issued.
        """
        try:
            fullcovar = self.m_like.obs().function().curvature().invert()
        except Exception:
            self.warning("Covariance matrix not determined successfully")
            return

        # 1) Indices of the free spectral parameters in the full matrix.
        # Other models are skipped by their total parameter count, i.e. the
        # matrix is assumed to have one row per parameter, free or fixed; the
        # index therefore also advances over fixed spectral parameters.
        idx = 0
        par_index_map = []
        for m in self.m_like.obs().models():
            if m.name() != self.m_name:
                idx += m.size()
                continue
            idx += m.spatial().size()
            for par in m.spectral():
                if par.is_free():
                    par_index_map.append(idx)
                idx += 1
            break

        if not par_index_map:
            self.warning("No free spectral parameter found for " + str(self.m_name))
            return

        # 2) Copy the selected elements into a square matrix.
        npar = len(par_index_map)
        self.m_covar = gl.GMatrix(npar, npar)
        for i, xpar in enumerate(par_index_map):
            for j, ypar in enumerate(par_index_map):
                self.m_covar[i, j] = fullcovar[xpar, ypar]
        self.m_has_covar = True
        self.success("Covariance Matrix succesfuly computed")

    def _makespectrum(self):
        """Evaluate the model flux (and its 1-sigma error) on a log-spaced grid.

        Energies run from m_ebound[0] to m_ebound[1] (in self.eunit) over
        self.m_npt points. The tables are rebuilt on every call.
        """
        self.m_energy = []
        self.m_flux = []
        self.m_error = []
        self.m_but = []
        self.m_enebut = []

        lmin = log10(self.m_ebound[0])
        lmax = log10(self.m_ebound[1])
        step = (lmax - lmin) / (self.m_npt - 1) if self.m_npt > 1 else 0.
        # Size of the gradient vector = number of free spectral parameters,
        # in the same order as the rows of self.m_covar.
        nfree = len([par for par in self.m_spectral if par.is_free()])

        for i in range(self.m_npt):
            ene = pow(10, lmin + step * i)
            genergy = gl.GEnergy(ene, self.eunit)
            self.m_energy.append(ene)
            self.m_flux.append(self.m_spectral.eval(genergy, gl.GTime(0)))

            if self.m_has_covar:
                # Error propagation: sigma^2 = g^T C g with g the gradients.
                self.m_spectral.eval_gradients(genergy, gl.GTime(0))
                derivative = gl.GVector(nfree)
                j = 0
                for par in self.m_spectral:
                    if par.is_free():
                        derivative[j] = par.factor_gradient()
                        j += 1
                self.m_error.append(sqrt(derivative * (self.m_covar * derivative)))

        # Convert the flux (and its error) to SED if requested.
        self.m_flux = self._convertSED(self.m_flux, self.m_energy)
        if self.m_has_covar:
            self.m_error = self._convertSED(self.m_error, self.m_energy)
            self._makebutterfly()
        self.success("Spectrum computed")

    def _makebutterfly(self):
        """Build the closed butterfly contour: upper edge forward, lower edge back."""
        upper = [f + e for f, e in zip(self.m_flux, self.m_error)]
        lower = [f - e for f, e in zip(self.m_flux, self.m_error)]
        self.m_but = upper + lower[::-1]
        self.m_enebut = list(self.m_energy) + list(self.m_energy)[::-1]
        if self.m_but:
            # Close the contour on its first point.
            self.m_but.append(self.m_but[0])
            self.m_enebut.append(self.m_enebut[0])

    def write(self):
        self.warning("write function not yet implemented")

    def draw(self):
        """Compute the spectrum and plot it with data points (top) and residuals (bottom)."""
        self._makespectrum()
        self.write()

        if not has_matplotlib:
            self.warning("matplotlib module not found, cannot draw")
            return

        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
        plt.figure("results", figsize=(12, 7), edgecolor="w")

        # Top panel: model, butterfly and data points.
        plt.subplot(gs[0])
        plt.loglog()
        plt.xlabel("Energy [" + self.eunit + "]", fontsize=15)
        if self.m_sed:
            plt.ylabel("E$^{2}$ dN/dE [erg cm$^{-2}$ s$^{-1}$]", fontsize=15)
        else:
            # NOTE(review): no unit conversion is applied to dN/dE here, so
            # the "per eunit" label presumes the model returns per eunit.
            plt.ylabel("dN/dE [cm$^{-2}$ s$^{-1}$ " + self.eunit + "$^{-1}$]", fontsize=15)

        plt.plot(self.m_energy, self.m_flux)
        plt.plot(self.m_enebut, self.m_but)

        # Data points on top of the butterfly.
        self.points._plt_points()
        self.points._plt_limits()

        # Bottom panel: residuals.
        plt.subplot(gs[1])
        plt.xscale('log')
        self.residuals._plt_residuals()

        plt.show()
382

    
383

    
384
# ============= #
385
# Plot skymaps  #
386
# ============= #
387
class MapPlotter(analysisutils.base):
    """Show a counts map and a model map side by side with aplpy."""

    def __init__(self, modelmap, countmap):
        """
        Parameters:
            modelmap -- model map, passed to aplpy.FITSFigure (right panel)
            countmap -- counts map, passed to aplpy.FITSFigure (left panel)
        """
        super(MapPlotter, self).__init__()
        self.modelmap = modelmap
        self.countmap = countmap

        if not has_aplpy:
            self.error("Module aplpy needed")
        if not has_matplotlib:
            self.error("Module matplotlib needed")

    def _panel(self, fig, image, subplot):
        """Create one aplpy panel for `image` at `subplot` with the shared styling."""
        panel = aplpy.FITSFigure(image, figure=fig, subplot=subplot)
        panel.set_tick_labels_font(size='small')
        panel.set_axis_labels_font(size='small')
        panel.show_colorscale()
        return panel

    def draw(self):
        """Plot the counts map (left panel) and the model map (right panel)."""
        if not (has_aplpy and has_matplotlib):
            self.warning("aplpy and matplotlib modules are needed, cannot draw")
            return

        fig = plt.figure()

        # Left panel: counts map, with RA/Dec tick formats and axis labels.
        f1 = self._panel(fig, self.countmap, [0.1, 0.1, 0.35, 0.8])
        f1.tick_labels.set_yformat('dd:mm:ss')
        f1.tick_labels.set_xformat('hh:mm:ss')
        f1.axis_labels.set_xtext('Right Ascension (J2000)')
        f1.axis_labels.set_ytext('Declination (J2000)')
        f1.ticks.set_length(10.5, minor_factor=0.5)

        # Right panel: model map; its y label and tick labels are hidden
        # (they are shown on the left panel).
        f2 = self._panel(fig, self.modelmap, [0.5, 0.1, 0.35, 0.8])
        f2.hide_yaxis_label()
        f2.hide_ytick_labels()

        fig.canvas.draw()
        fig.show()