OnlineTDPrediction
class OnlineTDPrediction(trace_decay: float, eligibility: Type[pandemonium.traces.EligibilityTrace] = pandemonium.traces.AccumulatingTrace, criterion: callable = smooth_l1_loss, **kwargs)

Bases: pandemonium.demons.online_td.OnlineTD, pandemonium.demons.prediction.TDPrediction
Semi-gradient \(\mathrm{TD}(\lambda)\) rule for estimating \(\tilde{v} \approx v_{\pi}\):

\[\begin{split}\begin{align*}
e_t &= \gamma_t \lambda e_{t-1} + \nabla \tilde{v}(x_t) \\
w_{t+1} &= w_t + \alpha \left( z_t + \gamma_{t+1} \tilde{v}(x_{t+1}) - \tilde{v}(x_t) \right) e_t
\end{align*}\end{split}\]
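A minimal sketch of this update rule in plain PyTorch is given below. The linear value function, feature dimension, hyper-parameter values, and the single synthetic transition are all assumptions chosen for illustration; this is not the pandemonium implementation, which wires the update through the configured eligibility trace and criterion.

    import torch

    torch.manual_seed(0)

    d = 4                                   # feature dimension (assumed)
    w = torch.zeros(d, requires_grad=True)  # weights of the approximate value function
    e = torch.zeros(d)                      # accumulating eligibility trace e_t
    alpha, lam = 0.1, 0.9                   # step size and trace_decay (assumed values)

    def v(x):
        # Linear value estimate v~(x) = w . x (illustrative choice of avf)
        return w @ x

    # One synthetic transition: (x_t, z_t, gamma_t, gamma_{t+1}, x_{t+1})
    x_t, x_next = torch.rand(d), torch.rand(d)
    z_t, gamma_t, gamma_next = 1.0, 0.99, 0.99

    # e_t = gamma_t * lambda * e_{t-1} + grad v~(x_t)
    v(x_t).backward()                       # populates w.grad with the gradient of v~(x_t)
    e = gamma_t * lam * e + w.grad

    # w_{t+1} = w_t + alpha * (z_t + gamma_{t+1} * v~(x_{t+1}) - v~(x_t)) * e_t
    with torch.no_grad():
        td_error = z_t + gamma_next * v(x_next) - v(x_t)
        w += alpha * td_error * e
        w.grad.zero_()

The class itself delegates the choice of trace to the eligibility argument (AccumulatingTrace by default) and presumably measures the error through criterion (smooth_l1_loss by default), so the raw rule shown here only illustrates the underlying update.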
References

- Sutton, R. S. and Barto, A. G. (2018). “Reinforcement Learning: An Introduction”, ch. 12.2. http://incompleteideas.net/book/the-book.html
Methods Summary

- delta(self, t): Specifies the update rule for the approximate value function (avf).
- target(self, t, v): One-step TD error target.
Methods Documentation
delta(self, t: pandemonium.experience.experience.Transition) → Tuple[Optional[torch.Tensor], dict]

Specifies the update rule for the approximate value function (avf).

Since the algorithms in this family are online, the update rule is applied on every Transition.
target(self, t: pandemonium.experience.experience.Transition, v: torch.Tensor)

One-step TD error target.
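For reference, the bootstrapped quantity implied by the update rule above is (the symbol \(u_t\) is introduced here purely for illustration and is not part of the API):

\[u_t = z_t + \gamma_{t+1} \tilde{v}(x_{t+1}), \qquad \delta_t = u_t - \tilde{v}(x_t)\]

so that, presumably, target computes this bootstrapped value while delta compares it against the current estimate \(\tilde{v}(x_t)\) to form the weight update.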