# plotting.py
#
# Classes to plot spectra (SED points, upper limits and fit residuals).
# Author: Sanchez David, 07/22/2014 04:27 PM
import matplotlib.pyplot as plt
2
import matplotlib.gridspec as gridspec
3
from gammalib import *
4
import asciidata
5
import math
6
import os
7
import sys
8
import pyfits
9

    
10
import numpy as np
11

    
12
def ascii_dict(fname, names=None, **kwargs):
    """Read an ASCII table with ``asciidata`` into a dict of numpy arrays.

    Parameters
    ----------
    fname : str
        Path of the table to read.
    names : sequence of str, optional
        Columns to extract.  When omitted or empty, every column of the table
        is returned, keyed by its column name.
    **kwargs
        Forwarded as keyword arguments to ``asciidata.open``.

    Returns
    -------
    dict
        Column name -> numpy array.  A requested column that is missing from
        the file maps to an empty array and a warning is printed.

    Notes
    -----
    Errors from ``asciidata.open`` propagate unchanged; callers such as
    ``ulimgraph`` presumably rely on an IOError for a missing file.
    """
    # Forward the options as keywords; passing the dict positionally would
    # bind it to asciidata.open's second positional parameter instead.
    fdata = asciidata.open(fname, **kwargs)
    if not names:
        return {col.colname: col.tonumpy() for col in fdata}
    d = {}
    for name in names:
        if fdata.find(name) >= 0:
            d[name] = fdata[name].tonumpy()
        else:
            d[name] = np.array([])
            # Single pre-joined string: same text as the old Python 2
            # multi-item print, and valid under both Python 2 and 3.
            print("WARNING: Column\" " + str(name) + " \" not found in " + str(fname))
    return d
class base(object):
    """Common base class providing colour-coded console messages.

    Every message carries the concrete subclass name (``self.classname``) so
    that output from different objects can be told apart.  ``error`` ends the
    program through ``sys.exit``.
    """

    def __init__(self):
        self.classname = self.__class__.__name__
        self.errorcolor = "\033[31m"  # red
        self.infocolor = "\033[34m"  # blue
        self.warningcolor = "\033[33m"  # yellow
        self.successcolor = "\033[32m"  # green
        self.endcolor = "\033[0m"  # reset
        # NOTE(review): not read anywhere in the visible code.
        self.prependfunction = True
        # Names of the attributes defined above.
        self.basemembers = ["classname", "errorcolor", "infocolor", "warningcolor",
                            "successcolor", "endcolor", "prependfunction"]

    def _tag(self, functionname=""):
        """Return ``classname`` or ``classname::functionname`` for message prefixes."""
        if functionname:
            return self.classname + "::" + functionname
        return self.classname

    def _write_inline(self, text):
        """Write *text* without a newline and flush so it shows up immediately.

        A trailing space is added, mirroring the separator that Python 2's
        trailing-comma ``print`` put between consecutive items.
        """
        sys.stdout.write(text + " ")
        sys.stdout.flush()

    def error(self, message, functionname=""):
        """Report *message* as an error and terminate.

        Raises
        ------
        SystemExit
            Always; the coloured message is the exit payload.
        """
        sys.exit("\n" + self.errorcolor + "*** Error [" + self._tag(functionname) + "]: "
                 + message + " ***\n" + self.endcolor)

    def info(self, message, newline=True):
        """Print an informational *message*.

        With ``newline=False`` only the bare coloured message is written (no
        ``[classname]`` prefix and no line break) so that later output
        continues on the same line.
        """
        if newline:
            print(self.infocolor + "[" + self.classname + "]: " + message + self.endcolor)
        else:
            self._write_inline(self.infocolor + message + self.endcolor)

    def warning(self, message, functionname=""):
        """Print *message* as a warning, optionally tagged with *functionname*."""
        print(self.warningcolor + "[" + self._tag(functionname) + "] Warning: "
              + message + self.endcolor)

    def success(self, message):
        """Print *message* in the success colour."""
        print(self.successcolor + "[" + self.classname + "]: " + message + self.endcolor)

    def progress(self, message="."):
        """Emit a progress marker on the current line.

        The output is flushed immediately, the same as ``info(newline=False)``.
        Without the flush, the markers would stay buffered until the next
        newline.
        """
        self._write_inline(self.infocolor + message + self.endcolor)
# Colour palette (matplotlib colour specifications) used by the plot classes.
BLUE = "#337EF5"
GREY = "grey"
GREEN = "#21851D"
RED = "#BD0B0B"
YELLOW = "#D4CF44"  # alternative shade: #D1A206
ORANGE = "#D19534"  # alternative shade: #EEBA76
def xfontsize(sbplt, fontsize=18):
    """Set the font size of the major x tick labels currently on axes *sbplt*."""
    labels = (tick.label1 for tick in sbplt.xaxis.get_major_ticks())
    for label in labels:
        label.set_fontsize(fontsize)
def yfontsize(sbplt, fontsize=18):
    """Set the font size of the major y tick labels currently on axes *sbplt*."""
    labels = (tick.label1 for tick in sbplt.yaxis.get_major_ticks())
    for label in labels:
        label.set_fontsize(fontsize)
def xyfontsize(sbplt, fontsize=18):
    """Apply *fontsize* to the major tick labels of both axes of *sbplt* (x first, then y)."""
    for set_axis_fontsize in (xfontsize, yfontsize):
        set_axis_fontsize(sbplt, fontsize)
def xticklines(sbplt, majorsize=7.0, majorwidth=1.5, minorsize=5.0, minorwidth=1.0):
    """Set length (marker size) and thickness (edge width) of the x tick marks.

    Major ticks get ``majorsize``/``majorwidth``, minor ticks
    ``minorsize``/``minorwidth``.
    """
    axis = sbplt.xaxis
    settings = ((axis.get_majorticklines, majorsize, majorwidth),
                (axis.get_minorticklines, minorsize, minorwidth))
    for get_lines, size, width in settings:
        for tickline in get_lines():
            tickline.set_markersize(size)
            tickline.set_markeredgewidth(width)
def yticklines(sbplt, majorsize=7.0, majorwidth=1.5, minorsize=5.0, minorwidth=1.0):
    """Set length (marker size) and thickness (edge width) of the y tick marks.

    Major ticks get ``majorsize``/``majorwidth``, minor ticks
    ``minorsize``/``minorwidth``.
    """
    axis = sbplt.yaxis
    settings = ((axis.get_majorticklines, majorsize, majorwidth),
                (axis.get_minorticklines, minorsize, minorwidth))
    for get_lines, size, width in settings:
        for tickline in get_lines():
            tickline.set_markersize(size)
            tickline.set_markeredgewidth(width)
def xyticklines(sbplt, majorsize=7.0, majorwidth=1.5, minorsize=5.0, minorwidth=1.0):
    """Style the tick marks of both axes of *sbplt* (x first, then y) with the same sizes."""
    sizes = (majorsize, majorwidth, minorsize, minorwidth)
    for style_axis in (xticklines, yticklines):
        style_axis(sbplt, *sizes)
class ulimgraph(base):
    """Draw a flux spectrum as data points plus upper-limit arrows.

    The spectrum is read from an ASCII table (through ``ascii_dict``).  A bin
    is drawn as a data point with error bars when its ``TS`` is at least
    ``tslim`` and its lower flux error is smaller than its flux.  Every other
    bin is drawn as an upper-limit arrow at its ``ulim_flux`` value.

    Parameters
    ----------
    filename : str
        Table to read.
    sed : bool
        If True the fluxes end up as E^2 dN/dE, otherwise as dN/dE.  The
        conversion depends on whether the ``flux`` column comment contains
        "/MeV" (see ``_load_dict``).
    """

    def __init__(self, filename, sed=True):
        super(ulimgraph, self).__init__()
        self.scalex = 1.0  # factor applied to energy columns by draw()
        self.scaley = 1.0  # factor applied to flux columns by draw()
        self.fmt = 'o'
        self.color = BLUE
        self.tslim = 4.0  # TS threshold separating points from upper limits
        self.elinewidth = 2.0
        self.capsize = 5.0
        self.markersize = 5.0
        self.limlength = 0.4  # limit arrow length as a fraction of the limit value
        self.sed = sed
        self.label = "_nolegend_"
        self.xerrors = True
        self.nlims = -1  # maximum number of limits drawn; -1 draws all
        self.filename = filename
        self._load_dict(filename)

    def _load_dict(self, filename):
        """Read *filename* into ``self.dict`` and normalise its columns.

        - A missing file (IOError) only warns and yields empty columns.
        - ``energy`` or ``emean`` (with their ``ed_``/``eu_`` errors) are
          accepted in place of ``ener``.  Missing energy errors become zeros.
        - A symmetric ``e_flux`` column provides both ``ed_flux`` and
          ``eu_flux``.  An ``ulim`` column provides ``ulim_flux``.
        - The flux conversion depends on ``self.sed`` and on the ``flux``
          column comment.  With ``sed=True`` and "/MeV" in the comment, the
          fluxes are multiplied by E^2.  With ``sed=False`` and no "/MeV" in
          the comment, they are divided by E^2.

        Exits through ``self.error`` when no energy column or no ``flux``
        column is present.
        """
        try:
            self.dict = ascii_dict(filename)
        except IOError:
            self.warning("File "+filename+" not existing - cant draw graph")
            self.dict = {}
            for key in ("flux", "eu_flux", "ed_flux", "ulim_flux",
                        "ener", "ed_ener", "eu_ener", "TS"):
                self.dict[key] = np.array([])

        if "ener" not in self.dict:
            self.warning("Keyword \"ener\" not found in file "+filename)
            for alias in ("energy", "emean"):
                if alias in self.dict:
                    self.info("Found \""+alias+"\" instead")
                    # Copy so "ener" never shares an array with the alias column.
                    self.dict["ener"] = self.dict[alias].copy()
                    if "ed_"+alias in self.dict and "eu_"+alias in self.dict:
                        self.dict["ed_ener"] = self.dict["ed_"+alias].copy()
                        self.dict["eu_ener"] = self.dict["eu_"+alias].copy()
                    break
            else:
                self.error("Required parameter \"ener\" not found in file "+filename)

        # Without energy errors, draw zero-width horizontal error bars.
        for key in ("ed_ener", "eu_ener"):
            if key not in self.dict:
                self.dict[key] = np.zeros(len(self.dict["ener"]))

        if "flux" not in self.dict:
            self.error("Required parameter \"flux\" not found in file "+filename)
        if "e_flux" in self.dict:
            self.warning("only symmetrical errors found")
            self.dict["ed_flux"] = self.dict["e_flux"].copy()
            self.dict["eu_flux"] = self.dict["e_flux"].copy()
        if "ulim" in self.dict:
            self.dict["ulim_flux"] = self.dict["ulim"].copy()

        comment = ""
        try:
            # A column without a comment may report None.
            comment = asciidata.open(filename)["flux"].colcomment or ""
        except IOError:
            pass

        # Build new arrays instead of scaling in place: in-place float
        # arithmetic fails on integer-typed columns.
        per_mev = "/MeV" in comment
        e2 = np.power(self.dict["ener"], 2.0)
        for key in ("flux", "eu_flux", "ed_flux", "ulim_flux"):
            if key not in self.dict:
                continue
            if self.sed and per_mev:
                self.dict[key] = self.dict[key] * e2
            elif not self.sed and not per_mev:
                self.dict[key] = self.dict[key] / e2

    def _is_point(self, i):
        """Return True if bin *i* is drawn as a data point, False if as an upper limit."""
        return (self.dict["TS"][i] >= self.tslim
                and self.dict["ed_flux"][i] < self.dict["flux"][i])

    def _plt_points(self):
        """Draw the significant bins as markers with error bars on the current axes."""
        ener, ed_ener, eu_ener = [], [], []
        flux, ed_flux, eu_flux = [], [], []
        for i in range(len(self.dict["flux"])):
            if self._is_point(i):
                ener.append(self.dict["ener"][i])
                ed_ener.append(self.dict["ed_ener"][i])
                eu_ener.append(self.dict["eu_ener"][i])
                flux.append(self.dict["flux"][i])
                ed_flux.append(self.dict["ed_flux"][i])
                eu_flux.append(self.dict["eu_flux"][i])
            elif self.dict["TS"][i] >= self.tslim:
                # A significant bin whose lower error reaches the flux is
                # drawn as an upper limit.  Flag it, because it is unexpected.
                self.warning("in "+os.path.basename(self.filename).replace(".sp","")+": ")
                self.warning("\tflux error: "+str(self.dict["ed_flux"][i])+" is larger than flux value "+str(self.dict["flux"][i]))
                self.warning("\tTS value is however TS="+str(self.dict["TS"][i])+", weird!")

        kwargs = dict(yerr=[ed_flux, eu_flux], fmt=self.fmt, elinewidth=self.elinewidth,
                      label=self.label, mfc=self.color, ecolor=self.color,
                      mec=self.color, ms=self.markersize)
        if self.xerrors:
            kwargs["xerr"] = [ed_ener, eu_ener]
        plt.errorbar(ener, flux, **kwargs)

    def _plt_limits(self):
        """Draw every non-point bin as an upper-limit arrow at its ``ulim_flux``.

        At most ``nlims`` limits are drawn unless ``nlims`` is -1.
        """
        ener, ed_ener, eu_ener = [], [], []
        flux, ed_flux, eu_flux = [], [], []
        for i in range(len(self.dict["flux"])):
            if self._is_point(i):
                continue
            if self.nlims != -1 and len(flux) >= self.nlims:
                break
            ulim = self.dict["ulim_flux"][i]
            ener.append(self.dict["ener"][i])
            ed_ener.append(self.dict["ed_ener"][i])
            eu_ener.append(self.dict["eu_ener"][i])
            flux.append(ulim)
            # The bar extends limlength*ulim below the limit and not above it.
            ed_flux.append(self.limlength * ulim)
            eu_flux.append(0.0)

        # markersize=0 hides the marker.  Passing ``ms`` as well is rejected
        # by matplotlib, because the two are aliases.
        kwargs = dict(yerr=[ed_flux, eu_flux], fmt=self.fmt, markersize=0.0,
                      elinewidth=self.elinewidth, lolims=True, capsize=self.capsize,
                      label='_nolegend_', mfc=self.color, ecolor=self.color)
        if self.xerrors:
            kwargs["xerr"] = [ed_ener, eu_ener]
        plt.errorbar(ener, flux, **kwargs)

    def draw(self):
        """Scale the columns and draw points and upper limits on the current axes.

        Every column whose name contains "ener" is multiplied by ``scalex``.
        Every other column whose name contains "flux" is multiplied by
        ``scaley``.  The scaled arrays replace the ones in ``self.dict``, so a
        second call scales again.
        """
        for key in list(self.dict):
            if "ener" in key:
                factor = self.scalex
            elif "flux" in key:
                factor = self.scaley
            else:
                continue
            # Rebind instead of scaling in place: two keys sharing one array
            # would otherwise be scaled twice.
            self.dict[key] = self.dict[key] * factor
        self._plt_points()
        self._plt_limits()
class residuals(base):
    """Plot fractional residuals (data - model) / model of a spectrum.

    The data come from a spectrum file read through ``ulimgraph`` with
    ``sed=False``.  The model is the spectral component of source
    ``self.source`` in the GammaLib model container ``modelfilename``.
    """

    def __init__(self, specfilename, modelfilename):
        super(residuals, self).__init__()
        self.scalex = 1.0  # factor applied to energy columns before drawing
        self.modelfilename = modelfilename
        # Default source name: model file name without ".optmodel"; callers may override it.
        self.source = os.path.basename(modelfilename).replace(".optmodel", "")
        self.xerrors = True
        self.fmt = 'o'
        self.markercolor = BLUE
        self.linecolor = 'black'
        self.tslim = 4.0  # only bins with TS >= tslim are shown
        self.elinewidth = 2.0
        self.linewidth = 2.0
        self.linestyle = '--'
        self._load_dict(specfilename)

    def _load_dict(self, filename):
        """Load the spectrum through ``ulimgraph(filename, False)`` and keep a shallow copy of its dict."""
        gr = ulimgraph(filename, False)
        self.dict = gr.dict.copy()

    def _plt_residuals(self):
        """Draw (flux - model) / model for the significant bins.

        A bin is shown when ``TS >= tslim`` and its lower flux error is
        smaller than its flux.  The lower and upper flux errors are each
        divided by the model value.
        """
        ener, ed_ener, eu_ener = [], [], []
        res, ed_res, eu_res = [], [], []
        time = GTime()
        for i in range(len(self.dict["flux"])):
            if not (self.dict["TS"][i] >= self.tslim
                    and self.dict["ed_flux"][i] < self.dict["flux"][i]):
                continue
            th_val = self.spectral.eval(self.energies[i], time)
            ener.append(self.dict["ener"][i])
            ed_ener.append(self.dict["ed_ener"][i])
            eu_ener.append(self.dict["eu_ener"][i])
            res.append((self.dict["flux"][i] - th_val) / th_val)
            ed_res.append(self.dict["ed_flux"][i] / th_val)
            eu_res.append(self.dict["eu_flux"][i] / th_val)

        kwargs = dict(yerr=[ed_res, eu_res], fmt=self.fmt, color=self.markercolor,
                      elinewidth=self.elinewidth)
        if self.xerrors:
            kwargs["xerr"] = [ed_ener, eu_ener]
        plt.errorbar(ener, res, **kwargs)

    def draw(self):
        """Load the model, draw the residuals and a reference line at zero.

        Exits through ``self.error`` if ``self.source`` cannot be taken from
        the model container.
        """
        models = GModels(self.modelfilename)
        try:
            self.spectral = models[self.source].spectral().clone()
        except Exception:
            self.error("Source \""+self.source+"\" not found in model container "+self.modelfilename)

        # Model evaluation energies are built from the unscaled "ener"
        # column, which this code treats as MeV.  Build them before scaling.
        self.energies = [GEnergy(float(e), "MeV") for e in self.dict["ener"]]
        for key in list(self.dict):
            if "ener" in key:
                # Rebind instead of scaling in place, so arrays shared
                # between keys (e.g. "ener" and "energy") are scaled once.
                self.dict[key] = self.dict[key] * self.scalex

        self._plt_residuals()
        plt.axhline(0.0, color=self.linecolor, lw=self.linewidth, ls=self.linestyle)
class plotter(base):
    """Draw the SED of one source together with its fit residuals.

    ``run`` reads ``<resultpath>/<roiname>.sp`` (spectral points) and
    ``<resultpath>/<roiname>.xml`` (GammaLib model container).  It shows a
    two-panel figure: the SED on top, fractional residuals below.

    Attributes that can be changed before ``run``:
        xlim, ylim      -- SED panel limits; first element 0 means automatic.
        ylim_res        -- residual panel y limits, same convention.
        resultpath      -- directory holding the input files.
        tslim           -- TS threshold separating points from upper limits.
        scalex, scaley  -- factors applied to energies / fluxes before drawing.
    """

    def __init__(self, roiname):
        super(plotter, self).__init__()
        self.xlim = [0, 0]
        self.ylim = [0, 0]
        self.ylim_res = [0, 0]
        # Fall back to the process working directory when $PWD is not set.
        self.resultpath = os.environ.get("PWD") or os.getcwd()
        self.roiname = roiname
        self.tslim = 4.0
        self.scalex = 1e-6  # MeV -> TeV for the energy axis
        self.scaley = 1.60217e-6  # MeV -> erg for the flux axis

    def run(self):
        """Build and show the SED + residual figure for ``self.roiname``."""
        prefix = os.path.join(self.resultpath, self.roiname)
        spfile = prefix + ".sp"
        optmodel = prefix + ".xml"

        residualplot = residuals(spfile, optmodel)
        residualplot.scalex = self.scalex
        # residuals() derives the source name by stripping ".optmodel" from
        # the model file name, which does not apply to ".xml"; set it explicitly.
        residualplot.source = self.roiname
        residualplot.tslim = self.tslim

        gr1 = ulimgraph(spfile)
        gr1.tslim = self.tslim
        gr1.scalex = self.scalex
        gr1.scaley = self.scaley

        plt.figure(num=self.roiname, figsize=(12, 7), edgecolor="w")
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])

        # Upper panel: spectral energy distribution.
        sb1 = plt.subplot(gs[0])
        plt.loglog()
        gr1.draw()
        plt.ylabel("E$^{2}$ dN/dE [erg cm$^{-2}$ s$^{-1}$]", fontsize=20)
        xyfontsize(sb1, fontsize=18)
        xyticklines(sb1, 7.0, 1.5, 5.0, 1.0)
        if self.xlim[0] != 0:
            plt.xlim(self.xlim)
        else:
            # Remember the automatic limits so the residual panel matches them.
            self.xlim = plt.xlim()
        if self.ylim[0] != 0:
            plt.ylim(self.ylim)
        plt.setp(plt.gca(), 'xticklabels', [])
        plt.tight_layout(pad=0.3, w_pad=0.3, h_pad=0.5)

        # Lower panel: fractional residuals with the same energy range.
        sb2 = plt.subplot(gs[1])
        plt.xscale('log')
        residualplot.draw()
        plt.xlim(self.xlim)
        if self.ylim_res[0] != 0:
            plt.ylim(self.ylim_res)
        plt.locator_params(axis='y', nbins=5)
        plt.xlabel("Energy [TeV]", fontsize=20)
        plt.ylabel("residual", fontsize=20)
        xyfontsize(sb2, fontsize=18)
        xyticklines(sb2, 7.0, 1.5, 5.0, 1.0)

        plt.subplots_adjust(left=0.11, right=0.97, top=0.9, bottom=0.11)
        plt.show()
if __name__ == "__main__":
    # Plot the "CrabNebula" spectrum and residuals from the working directory.
    plotter("CrabNebula").run()