REST-for-Physics  v2.3
Rare Event Searches ToolKit for Physics
TRestDataSetOdds.cxx
1 /*************************************************************************
2  * This file is part of the REST software framework. *
3  * *
4  * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) *
5  * For more information see https://gifna.unizar.es/trex *
6  * *
7  * REST is free software: you can redistribute it and/or modify *
8  * it under the terms of the GNU General Public License as published by *
9  * the Free Software Foundation, either version 3 of the License, or *
10  * (at your option) any later version. *
11  * *
12  * REST is distributed in the hope that it will be useful, *
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
15  * GNU General Public License for more details. *
16  * *
17  * You should have a copy of the GNU General Public License along with *
18  * REST in $REST_PATH/LICENSE. *
19  * If not, see https://www.gnu.org/licenses/. *
20  * For the list of contributors see $REST_PATH/CREDITS. *
21  *************************************************************************/
22 
102 
103 #include "TRestDataSetOdds.h"
104 
105 #include "TRestDataSet.h"
106 
107 ClassImp(TRestDataSetOdds);
108 
113 
/// Constructor taking a configuration file name and a section name.
///
/// In the lines visible here, `configFilename` is forwarded to the TRestMetadata base
/// constructor and Initialize() is called.
///
/// NOTE(review): the embedded listing numbers jump 129 -> 131 and 132 -> 134, so two
/// statements appear to have been dropped from this extract. `name` is not used by any
/// of the lines shown; confirm against the original file before relying on this body.
128 TRestDataSetOdds::TRestDataSetOdds(const char* configFilename, std::string name)
129  : TRestMetadata(configFilename) {
131  Initialize();
132 
134 }
135 
140 
145 void TRestDataSetOdds::Initialize() { SetSectionName(this->ClassName()); }
146 
// NOTE(review): the function header is missing from this listing (the embedded numbers
// skip 147-151 and 153). Judging by the member index further down, this is presumably the
// body of TRestDataSetOdds::InitFromConfigFile() -- confirm against the original file.
//
// Visible behaviour:
//  - calls Initialize();
//  - walks every "observable" element and reads its "name", "range" and "nBins" fields,
//    appending them to fObsName, fObsRange and fObsNbins respectively;
//  - a field that is empty or equal to "Not defined" is reported through RESTError and
//    terminates the process with exit(1);
//  - reads "outputFileName" (only when fOutputFileName is still empty) and instantiates
//    the "TRestCut" child metadata into fCut.
152  Initialize();
154 
155  TiXmlElement* obsDefinition = GetElement("observable");
156  while (obsDefinition != nullptr) {
    // Observable name: mandatory.
157  std::string obsName = GetFieldValue("name", obsDefinition);
158  if (obsName.empty() || obsName == "Not defined") {
159  RESTError << "< observable variable key does not contain a name!" << RESTendl;
160  exit(1);
161  } else {
162  fObsName.push_back(obsName);
163  }
164 
    // Observable range: mandatory, converted with StringTo2DVector.
165  std::string range = GetFieldValue("range", obsDefinition);
166  if (range.empty() || range == "Not defined") {
167  RESTError << "< observable key does not contain a range value!" << RESTendl;
168  exit(1);
169  } else {
170  TVector2 roi = StringTo2DVector(range);
171  fObsRange.push_back(roi);
172  }
173 
    // Number of bins: mandatory, converted with StringToInteger.
174  std::string nBins = GetFieldValue("nBins", obsDefinition);
175  if (nBins.empty() || nBins == "Not defined") {
176  RESTError << "< observable key does not contain a nBins value!" << RESTendl;
177  exit(1);
178  } else {
179  fObsNbins.push_back(StringToInteger(nBins));
180  }
181 
    // Advance to the next element; a null result ends the loop.
182  obsDefinition = GetNextElement(obsDefinition);
183  }
184 
    // At least one observable is required. fObsNbins is not part of this check.
185  if (fObsName.empty() || fObsRange.empty()) {
186  RESTError << "No observables provided, exiting..." << RESTendl;
187  exit(1);
188  }
189 
    // A value already stored in fOutputFileName takes precedence over the config parameter.
190  if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", "");
191 
192  fCut = (TRestCut*)InstantiateChildMetadata("TRestCut");
193 }
194 
// NOTE(review): the function header is missing from this listing (the embedded numbers
// skip 195-204). Judging by the member index further down, this is presumably the body of
// TRestDataSetOdds::ComputeLogOdds() -- confirm against the original file.
//
// Visible behaviour:
//  1. Imports the dataset named by fDataSetName.
//  2. Fills fHistos with one histogram per observable: either built from the dataset after
//     applying fCut and normalised by its integral (when fOddsFile is empty), or read from
//     fOddsFile under the name "h" + observable name.
//  3. Defines a column "odds_<observable>" for each entry of fHistos and a column
//     "odds_total" holding their sum.
//  4. Stores the resulting data frame back in the dataset and, when fOutputFileName is not
//     empty, exports the dataset and writes this object and the histograms to that file.
205  PrintMetadata();
206 
207  TRestDataSet dataSet;
208  dataSet.Import(fDataSetName);
209 
210  if (fOddsFile.empty()) {
    // No odds file given: build the histograms from this dataset with fCut applied.
211  auto DF = dataSet.MakeCut(fCut);
212  RESTInfo << "Generating PDFs for dataset: " << fDataSetName << RESTendl;
213  for (size_t i = 0; i < fObsName.size(); i++) {
214  const std::string obsName = fObsName[i];
215  const TVector2 range = fObsRange[i];
216  const std::string histName = "h" + obsName;
217  const int nBins = fObsNbins[i];
218  RESTDebug << "\tGenerating PDF for " << obsName << " with range: (" << range.X() << ", "
219  << range.Y() << ") and nBins: " << nBins << RESTendl;
220  auto histo =
221  DF.Histo1D({histName.c_str(), histName.c_str(), nBins, range.X(), range.Y()}, obsName);
    // The object stored in fHistos is the one returned by DrawClone(), not the data frame result.
222  TH1F* h = static_cast<TH1F*>(histo->DrawClone());
223  RESTDebug << "\tNormalizing by integral = " << h->Integral() << RESTendl;
    // NOTE(review): the integral is not checked against zero before dividing.
224  h->Scale(1. / h->Integral());
225  fHistos[obsName] = h;
226  }
227  } else {
    // Odds file given: read the histograms from it instead of generating them.
228  TFile* f = TFile::Open(fOddsFile.c_str());
229  if (f == nullptr) {
230  RESTError << "Cannot open calibration odds file " << fOddsFile << RESTendl;
231  exit(1);
232  }
233  RESTInfo << "Opening " << fOddsFile << " as oddsFile." << RESTendl;
234  for (size_t i = 0; i < fObsName.size(); i++) {
235  const std::string obsName = fObsName[i];
236  const std::string histName = "h" + obsName;
    // NOTE(review): the pointer returned by Get() is stored without a null check and is
    // dereferenced later inside GetLogOdds.
237  TH1F* h = (TH1F*)f->Get(histName.c_str());
238  fHistos[obsName] = h;
239  }
240  }
241 
242  auto df = dataSet.GetDataFrame();
    // totName accumulates the expression "odds_a+odds_b+..." used to define "odds_total".
243  std::string totName = "";
244  RESTDebug << "Computing log odds from " << fDataSetName << RESTendl;
245  for (const auto& [obsName, histo] : fHistos) {
246  const std::string oddsName = "odds_" + obsName;
    // Maps a value to log(1 - p) - log(p), where p is the content of the bin holding the value.
    // A bin content of exactly 0 returns the fixed value 1000 instead.
    // NOTE(review): p == 1 is not special-cased, so log(1. - odds) is evaluated at 0 there.
    // The init-capture is needed because C++17 lambdas cannot capture a structured binding directly.
247  auto GetLogOdds = [&histo = histo](double val) {
248  double odds = histo->GetBinContent(histo->GetXaxis()->FindBin(val));
249  if (odds == 0) return 1000.;
250  return log(1. - odds) - log(odds);
251  };
252 
    // GetLogOdds takes a double, so columns reported with another type are redefined through a cast.
253  if (df.GetColumnType(obsName) != "Double_t") {
254  RESTWarning << "Column " << obsName << " is not of type 'double'. It will be converted."
255  << RESTendl;
256  df = df.Redefine(obsName, "static_cast<double>(" + obsName + ")");
257  }
258  df = df.Define(oddsName, GetLogOdds, {obsName});
    // NOTE(review): `h` is not used by any later line visible here.
259  auto h = df.Histo1D(oddsName);
260 
261  if (!totName.empty()) totName += "+";
262  totName += oddsName;
263  }
264 
265  RESTDebug << "Computing total log odds" << RESTendl;
266  RESTDebug << "\tTotal log odds = " << totName << RESTendl;
267  df = df.Define("odds_total", totName);
268 
269  dataSet.SetDataFrame(df);
270 
271  if (!fOutputFileName.empty()) {
    // NOTE(review): the embedded numbering skips 272 and the brace on line 281 has no visible
    // opener, so a condition wrapping the statements below was presumably dropped from this
    // extract -- confirm against the original file.
273  RESTDebug << "Exporting dataset to " << fOutputFileName << RESTendl;
274  dataSet.Export(fOutputFileName);
    // Reopen the exported file in UPDATE mode to add this object and the histograms.
    // NOTE(review): `f` is not checked for null before f->Close().
275  TFile* f = TFile::Open(fOutputFileName.c_str(), "UPDATE");
276  this->Write();
277  RESTDebug << "Writing histograms to " << fOutputFileName << RESTendl;
278  for (const auto& [obsName, histo] : fHistos) histo->Write();
279  f->Close();
280  }
281  }
282 }
283 
284 std::vector<std::tuple<std::string, TVector2, int>> TRestDataSetOdds::GetOddsObservables() {
285  std::vector<std::tuple<std::string, TVector2, int>> obs;
286  for (size_t i = 0; i < fObsName.size(); i++) {
287  if (i >= fObsName.size() || i >= fObsRange.size() || i >= fObsNbins.size()) {
288  RESTError << "Sizes for observables names, ranges and bins do not match!" << RESTendl;
289  break;
290  }
291  obs.push_back(std::make_tuple(fObsName[i], fObsRange[i], fObsNbins[i]));
292  }
293  return obs;
294 }
295 
296 void TRestDataSetOdds::AddOddsObservable(const std::string& name, const TVector2& range, int nbins) {
297  fObsName.push_back(name);
298  fObsRange.push_back(range);
299  fObsNbins.push_back(nbins);
300 }
301 
302 void TRestDataSetOdds::SetOddsObservables(const std::vector<std::tuple<std::string, TVector2, int>>& obs) {
303  fObsName.clear();
304  fObsRange.clear();
305  fObsNbins.clear();
306  for (const auto& [name, range, nbins] : obs) AddOddsObservable(name, range, nbins);
307 }
308 
314 
// NOTE(review): the function header is missing from this listing (the embedded numbers
// 309-313 are absent). Judging by the member index further down, this is presumably the
// body of TRestDataSetOdds::PrintMetadata() -- confirm against the original file.
//
// Visible behaviour: prints the odds file name (only when fOddsFile is not empty), the
// dataset file name, and one line per observable with its name, range and number of bins.
// Printing of fCut is commented out.
315  // if (fCut) fCut->PrintMetadata();
316  if (!fOddsFile.empty()) RESTMetadata << " Odds file: " << fOddsFile << RESTendl;
317  RESTMetadata << " DataSet file: " << fDataSetName << RESTendl;
318 
319  RESTMetadata << " Observables to compute: " << RESTendl;
    // The loop is bounded by fObsName only; fObsRange and fObsNbins are indexed without a size check.
320  for (size_t i = 0; i < fObsName.size(); i++) {
321  RESTMetadata << fObsName[i] << "; Range: (" << fObsRange[i].X() << ", " << fObsRange[i].Y()
322  << "); nBins: " << fObsNbins[i] << RESTendl;
323  }
324  RESTMetadata << "----" << RESTendl;
325 }
A class to help on cuts definitions. To be used with TRestAnalysisTree.
Definition: TRestCut.h:31
This class is meant to compute the log odds for different datasets.
std::string fOddsFile
Name of the odds file to be used to get the PDF.
std::vector< std::string > fObsName
Vector containing different observable names.
TRestDataSetOdds()
Default constructor.
void PrintMetadata() override
Prints on screen the information about the metadata members of TRestDataSetOdds.
std::map< std::string, TH1F * > fHistos
Map containing the PDF of the different observables.
std::vector< TVector2 > fObsRange
Vector containing different observable ranges.
void Initialize() override
Function to initialize input/output event members and define the section name.
~TRestDataSetOdds()
Default destructor.
TRestCut * fCut
Cuts over the dataset for PDF selection.
void ComputeLogOdds()
This function computes the log odds for a given dataSet. If no calibration odds file is provided it c...
std::string fOutputFileName
Name of the output file.
std::string fDataSetName
Name of the dataSet inside the config file.
std::vector< int > fObsNbins
Vector containing number of bins for the different observables.
void InitFromConfigFile() override
Function to initialize some variables from configfile.
It allows to group a number of runs that satisfy given metadata conditions.
Definition: TRestDataSet.h:34
void Import(const std::string &fileName)
This function imports metadata from a ROOT file. It imports metadata info from the previous dataSet whi...
ROOT::RDF::RNode GetDataFrame() const
Gives access to the RDataFrame.
Definition: TRestDataSet.h:129
ROOT::RDF::RNode MakeCut(const TRestCut *cut)
This function applies a TRestCut to the dataframe and returns a dataframe with the applied cuts....
void Export(const std::string &filename, std::vector< std::string > excludeColumns={})
It will generate an output file with the dataset compilation. Only the selected branches and the file...
A base class for any REST metadata class.
Definition: TRestMetadata.h:74
virtual void PrintMetadata()
Implement it in the derived metadata class to print out specific metadata information.
endl_t RESTendl
Termination flag object for TRestStringOutput.
TiXmlElement * GetElement(std::string eleDeclare, TiXmlElement *e=nullptr)
Get an xml element from a given parent element, according to its declaration.
Int_t LoadConfigFromFile(const std::string &configFilename, const std::string &sectionName="")
Give the file name, find out the corresponding section. Then call the main starter.
TRestMetadata * InstantiateChildMetadata(int index, std::string pattern="")
This method will retrieve a new TRestMetadata instance of a child element of the present TRestMetadat...
virtual void InitFromConfigFile()
To make settings from rml file. This method must be implemented in the derived class.
TRestStringOutput::REST_Verbose_Level GetVerboseLevel()
Returns the verbose level as a REST_Verbose_Level enumerator value.
std::string GetFieldValue(std::string parName, TiXmlElement *e)
Returns the field value of an xml element which has the specified name.
void SetSectionName(std::string sName)
set the section name, clear the section content
std::string fConfigFileName
Full name of the rml file.
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
overwriting the write() method with fStore considered
TiXmlElement * GetNextElement(TiXmlElement *e)
Get the next sibling xml element of this element, with same eleDeclare.
std::string GetParameter(std::string parName, TiXmlElement *e, TString defaultValue=PARAMETER_NOT_FOUND_STR)
Returns the value for the parameter named parName in the given section.
@ REST_Info
+show most of the information for each steps
static std::string GetFileNameExtension(const std::string &fullname)
Gets the file extension as the substring found after the latest ".".
Definition: TRestTools.cxx:823
Int_t StringToInteger(std::string in)
Gets an integer from a string.
TVector2 StringTo2DVector(std::string in)
Gets a 2D-vector from a string.