PO_ParallelETUCT.hh
Go to the documentation of this file.
00001 
00012 #ifndef _PO_ParallelETUCT_HH_
00013 #define _PO_ParallelETUCT_HH_
00014 
00015 #include <rl_common/Random.h>
00016 #include <rl_common/core.hh>
00017 #include <rl_common/ExperienceFile.hh>
00018 
00019 #include "../Models/FactoredModel.hh"
00020 #include "../Models/C45Tree.hh"
00021 
00022 #include <set>
00023 #include <vector>
00024 #include <map>
00025 #include <sstream>
00026 #include <deque>
00027 
/// Free function taking and returning void*, which is the start-routine
/// signature pthread_create() accepts.
/// NOTE(review): presumably the entry point for the planning thread, with
/// `arg` being the PO_ParallelETUCT instance -- confirm in the .cc file.
00029 void* poParallelSearchStart(void* arg);
00030 
/// Free function taking and returning void*, which is the start-routine
/// signature pthread_create() accepts.
/// NOTE(review): presumably the entry point for the model-learning thread,
/// with `arg` being the PO_ParallelETUCT instance -- confirm in the .cc file.
00032 void* poParallelModelLearningStart(void* arg);
00033 
/**
 * Planner subclass that declares a UCT search (uctSearch) together with two
 * pthread handles (planThread, modelThread) and the mutexes / condition
 * variable declared alongside the data those threads are expected to share.
 *
 * Only declarations are visible in this header; the behaviour of every
 * method lives in the implementation file.
 *
 * NOTE(review): "PO" / "ET" in the name presumably stand for partially
 * observable / eligibility traces (cf. the lambda and historySize
 * constructor arguments) -- confirm against the implementation.
 * NOTE(review): pthread_* types and an unqualified `ofstream` are used below,
 * but this header shows no <pthread.h> / <fstream> include and no
 * using-declaration; presumably they arrive via the rl_common includes --
 * verify.
 */
00035 class PO_ParallelETUCT: public Planner {
00036 public:
00037 
        /// Main constructor.
        /// Parameter meanings are inferred from their names and from the
        /// identically named members below -- TODO confirm in the .cc file.
        /// @param numactions   presumably the number of actions in the domain
        /// @param gamma        presumably the discount factor
        /// @param rrange       presumably the range of one-step rewards
        /// @param lambda       presumably the eligibility-trace parameter
        /// @param MAX_ITER     presumably an upper bound on search iterations
        /// @param MAX_TIME     presumably an upper bound on planning time
        /// @param MAX_DEPTH    presumably the maximum search depth
        /// @param modelType    integer selecting the model type
        /// @param featmax      per-feature maximum values (presumed)
        /// @param featmin      per-feature minimum values (presumed)
        /// @param statesPerDim presumably the number of discrete values per
        ///                     feature; see the lifetime note on the member
        /// @param trackActual  presumably whether real-valued states are
        ///                     tracked in addition to discretized ones
        /// @param historySize  presumably the number of past entries kept in
        ///                     the history
        /// @param rng          random number generator; defaults to a
        ///                     default-constructed Random
00054   PO_ParallelETUCT(int numactions, float gamma, float rrange, float lambda,
00055                    int MAX_ITER, float MAX_TIME, int MAX_DEPTH,  int modelType,
00056                    const std::vector<float> &featmax, const std::vector<float> &featmin,
00057                    const std::vector<int> &statesPerDim, bool trackActual, int historySize, 
00058                    Random rng = Random());
00059   
        /// Copy constructor (declared here; its definition is not visible).
00062   PO_ParallelETUCT(const PO_ParallelETUCT &);
00063 
        /// Virtual destructor.
00064   virtual ~PO_ParallelETUCT();
00065 
        // The virtual methods below presumably implement the Planner
        // interface; Planner itself is not visible in this header.

        /// Hands an MDPModel pointer to the planner.
00066   virtual void setModel(MDPModel* model);
        /// Supplies one transition (last state, action, current state, reward,
        /// terminal flag). The meaning of the bool result is not visible here.
00067   virtual bool updateModelWithExperience(const std::vector<float> &last, 
00068                                          int act, 
00069                                          const std::vector<float> &curr, 
00070                                          float reward, bool term);
        /// Presumably triggers planning after the model changed -- confirm.
00071   virtual void planOnNewModel();
        /// Returns an action index for state s.
00072   virtual int getBestAction(const std::vector<float> &s);
00073 
        /// Sets a seeding flag (presumably stored in seedMode -- confirm).
00074   virtual void setSeeding(bool seed);
        /// NOTE(review): purpose not derivable from this header; presumably
        /// marks the start of a new episode -- confirm.
00075   virtual void setFirst();
00076 
        // Public boolean debug switches; presumably each enables diagnostic
        // output for the subsystem its name suggests.
00077   bool PLANNERDEBUG;
00078   bool POLICYDEBUG; //= false; //true;
00079   bool MODELDEBUG;
00080   bool ACTDEBUG;
00081   bool UCTDEBUG;
00082   bool PTHREADDEBUG;
00083   bool ATHREADDEBUG;
00084   bool MTHREADDEBUG;
00085   bool TIMINGDEBUG;
00086   bool REALSTATEDEBUG;
00087   bool HISTORYDEBUG;
00088 
        /// Pointer to an MDPModel (presumably the one given to setModel).
00090   MDPModel* model;
00091   
        /// Second MDPModel pointer; presumably a working copy of `model` --
        /// confirm ownership and which thread uses it.
00093   MDPModel* modelcopy;
00094 
        /// state_t is a pointer to a const vector of floats. Presumably it
        /// points at an element of `statespace` (see canonicalize) -- confirm.
00098   typedef const std::vector<float> *state_t;
00099 
00100 
00101 
00103   // parallel stuff
00105 
00106   // tell if thread started
00107   bool modelThreadStarted;
00108   bool planThreadStarted;
00109 
00110   // the threads
        /// pthread handle; presumably the thread running parallelSearch().
00112   pthread_t planThread;
00113   
        /// pthread handle; presumably the thread running parallelModelLearning().
00115   pthread_t modelThread;
00116 
00117   // some variables that are locked
        /// List of `experience` records; presumably the queue of transitions
        /// waiting to be added to the model (cf. list_mutex / list_cond).
00119   std::vector<experience> expList;
00120   
        /// Presumably the discretized state planning currently starts from.
00122   state_t discPlanState;
00123   
        /// Presumably the real-valued counterpart of discPlanState.
00125   std::vector<float> actualPlanState;
00126   
        /// NOTE(review): role relative to discPlanState not visible here.
00128   state_t startState;
00129 
00130   // lock over simple objects
        // NOTE(review): which data each mutex guards is inferred from its
        // name only; confirm against the lock/unlock sites in the .cc file.
00132   pthread_mutex_t update_mutex;
00134   pthread_mutex_t nactions_mutex;
00136   pthread_mutex_t plan_state_mutex;
00138   pthread_mutex_t model_mutex;
00140   pthread_mutex_t list_mutex;
00142   pthread_mutex_t history_mutex;
00143 
        /// Presumably guards `statespace` / `statedata` -- confirm.
00145   pthread_mutex_t statespace_mutex;
00146 
00147   // condition for when list is updated
00148   pthread_cond_t list_cond; 
00149 
00150 
        /// UCT search step for `state` at the given depth.
        /// @param actS   real-valued state vector (presumably the actual state
        ///               corresponding to `state`)
        /// @param state  canonical state pointer
        /// @param depth  current search depth
        /// @return a float; presumably the value estimate of the rollout.
00162   float uctSearch(const std::vector<float> &actS, state_t state, int depth);
00163   
        /// Returns a state vector by value; presumably chosen at random.
00165   std::vector<float> selectRandomState();
00166   
        /// Presumably the body of the model-learning thread -- confirm.
00168   void parallelModelLearning();
00169   
        /// Presumably the body of the planning thread -- confirm.
00171   void parallelSearch();
00172   
        /// Loads a policy from the named file (format not visible here).
00174   void loadPolicy(const char* filename);
00175   
        /// Writes values to the stream `of` for the given x/y bounds;
        /// presumably for a 2-D slice of the state space -- confirm.
00177   void logValues(ofstream *of, int xmin, int xmax, int ymin, int ymax);
00178   
        /// Returns a vector computed from a and b; presumably element-wise a + b.
00180   std::vector<float> addVec(const std::vector<float> &a, const std::vector<float> &b);
00181   
        /// Returns a vector computed from a and b; presumably element-wise a - b.
00183   std::vector<float> subVec(const std::vector<float> &a, const std::vector<float> &b);
00184 
00185 protected:
00186 
00187 
00188   struct state_info;
00189   struct model_info;
00190 
        /// Holds a list of canonical state pointers.
00192   struct state_samples {
00193     std::vector<state_t> samples;
00194   };
00195 
        /// Per-state bookkeeping stored in `statedata`: model information,
        /// Q values, UCT visit counters and two per-state mutexes.
00197   struct state_info {
00198 
00199     // data filled in from models
00200     StateActionInfo* model;
00201 
00202     // q values from policy creation
00203     std::vector<float> Q;
00204 
00205     // uct experience data
00206     int uctVisits;
00207     std::vector<int> uctActions;
00208     short unsigned int visited;
00209     short unsigned int id;
00210 
00211     // needs update
00212     bool needsUpdate;
00213 
00214     // mutex for model info, samples, everything else
00215     pthread_mutex_t statemodel_mutex;
00216     pthread_mutex_t stateinfo_mutex;
00217 
00218   };
00219 
        /// Initializes `info` for state `s` with the given id.
00221   void initStateInfo(state_t s,state_info* info, int id);
00222   
        /// Maps a state vector to a state_t pointer; presumably inserts it
        /// into `statespace` if absent and returns the stored element.
00226   state_t canonicalize(const std::vector<float> &s);
00227 
        /// Releases whatever `info` holds (details in the .cc file).
00229   void deleteInfo(state_info* info);
00230   
        /// Presumably derives the policy from the current Q values -- confirm.
00232   void createPolicy();
00233   
        /// Presumably prints the known states for debugging.
00235   void printStates();
00236   
        /// Presumably marks which states are reachable -- confirm.
00238   void calculateReachableStates();
00239   
        /// Presumably drops states not marked reachable -- confirm.
00241   void removeUnreachableStates();
00242 
        /// Refreshes the stored information for (s, a) in `info`, presumably
        /// by querying the model.
00244   void updateStateActionFromModel(state_t s, int a, state_info* info);
00245   
        /// Fills `newModel` for (modState, a), presumably by querying the model.
        /// NOTE(review): modState is taken by const value, so every call
        /// copies the vector; changing it needs a matching change in the .cc.
00247   void updateStateActionHistoryFromModel(const std::vector<float> modState, int a, StateActionInfo *newModel);
00248 
        /// Returns a time value in seconds as a double (clock source not
        /// visible here).
00250   double getSeconds();
00251 
00252   // uct stuff
        /// Presumably resets per-state data and re-queries the model.
00254   void resetAndUpdateStateActions();
00255 
        /// Produces a next-state vector for taking `action` in `state`.
        /// `reward` and `term` are pointer parameters, presumably outputs for
        /// the sampled reward and terminal flag -- confirm.
00258   std::vector<float> simulateNextState(const std::vector<float> &actS, state_t state, state_info* info, int action, float* reward, bool* term);
00259   
        /// Returns an action index chosen from `info`; presumably using the
        /// UCT upper-confidence rule.
00261   int selectUCTAction(state_info* info);
00262   
        /// Presumably canonicalizes the next states referenced by modelInfo.
00264   void canonNextStates(StateActionInfo* modelInfo);
00265   
        /// Presumably initializes the state space (see fillInState).
00267   void initStates();
00268   
        /// Helper taking a state vector and a depth; presumably recursive.
        /// NOTE(review): `s` is passed by value (copied per call); confirm
        /// whether the copy is intentional before changing the signature.
00270   void fillInState(std::vector<float>s, int depth);
00271 
        /// Writes the policy to the named file (format not visible here).
00272   virtual void savePolicy(const char* filename);
00273   
        /// Returns a vector derived from s; presumably s snapped to the grid
        /// implied by featmin / featmax / statesPerDim.
00275   std::vector<float> discretizeState(const std::vector<float> &s);
00276 
00277 private:
00278 
        /// Set of state vectors. state_t values presumably point at these
        /// elements (std::set keeps element addresses stable) -- confirm.
00282   std::set<std::vector<float> > statespace;
00283 
        /// Per-state bookkeeping keyed by canonical state pointer.
00285   std::map<state_t, state_info> statedata;
00286 
00287   std::vector<float> featmax;
00288   std::vector<float> featmin;
00289 
        /// Deque of floats; presumably the flattened recent state/action
        /// history (cf. HISTORY_SIZE / HISTORY_FL_SIZE, history_mutex).
00291   std::deque<float> saHistory;
00292 
        // Presumably the previous step's state, action and state_info.
00293   state_t prevstate;
00294   int prevact;
00295   state_info* previnfo;
00296 
        // Presumably timing bookkeeping (values compared with getSeconds()).
00297   double planTime;
00298   double initTime;
00299   double setTime;
00300   bool seedMode;
00301 
00302   int nstates;
00303   int nsaved;
00304   int nactions;
00305   int lastUpdate;
00306 
00307   bool timingType;
00308 
        // Const configuration; names match the constructor parameters, so
        // these are presumably initialized from them.
00309   const int numactions;
00310   const float gamma;
00311   const float rrange;
00312   const float lambda;
00313 
00314   const int MAX_ITER;
00315   const float MAX_TIME;
00316   const int MAX_DEPTH;
00317   const int modelType;
        // NOTE(review): this member is a reference, not a copy. If it is
        // bound to the constructor's statesPerDim argument (presumed), the
        // caller's vector must outlive this object or the reference dangles.
00318   const std::vector<int> &statesPerDim;
00319   const bool trackActual;
00320   const int HISTORY_SIZE;
00321   const int HISTORY_FL_SIZE;
00322 
        /// ExperienceFile object (type from rl_common/ExperienceFile.hh).
00323   ExperienceFile expfile;
00324 };
00325 
00326 #endif


rl_agent
Author(s): Todd Hester
autogenerated on Thu Jun 6 2019 22:00:13