OnlineTDPrediction

class OnlineTDPrediction(trace_decay: float, eligibility: Type[pandemonium.traces.EligibilityTrace] = pandemonium.traces.AccumulatingTrace, criterion: callable = smooth_l1_loss, **kwargs)

Bases: pandemonium.demons.online_td.OnlineTD, pandemonium.demons.prediction.TDPrediction
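The `eligibility` parameter defaults to `AccumulatingTrace`. The following plain-Python sketch shows what an accumulating trace does under the trace recursion given in this docstring, and contrasts it with a replacing trace for binary features. It assumes a linear value function, so the gradient is simply the feature vector; `accumulating_trace` and `replacing_trace` are hypothetical helpers written for illustration, not pandemonium's API.

```python
from typing import List

Vector = List[float]

def accumulating_trace(e: Vector, grad: Vector, gamma: float, lam: float) -> Vector:
    # e_t = gamma_t * lam * e_{t-1} + grad v(x_t): every visit adds to the trace
    return [gamma * lam * ei + gi for ei, gi in zip(e, grad)]

def replacing_trace(e: Vector, grad: Vector, gamma: float, lam: float) -> Vector:
    # Hypothetical comparison, binary features only:
    # active components are reset to 1, inactive ones just decay
    return [1.0 if gi != 0 else gamma * lam * ei for ei, gi in zip(e, grad)]

# Visit the same binary feature vector three times in a row
x = [1.0, 0.0]
acc, rep = [0.0, 0.0], [0.0, 0.0]
for _ in range(3):
    acc = accumulating_trace(acc, x, gamma=0.9, lam=0.8)
    rep = replacing_trace(rep, x, gamma=0.9, lam=0.8)

print(round(acc[0], 4), rep[0])  # 2.2384 1.0
```

With repeated visits the accumulating trace grows past 1 (here 1 + 0.72 + 0.72²), while the replacing trace saturates at 1; unvisited components stay at 0 in both.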

Semi-gradient \(\mathrm{TD}(\lambda)\) rule for estimating \(\tilde{v} \approx v_{\pi}\).

\[\begin{align*} e_t &= \gamma_t \lambda e_{t-1} + \nabla \tilde{v}(x_t) \\ w_{t+1} &= w_t + \alpha \left(z_t + \gamma_{t+1} \tilde{v}(x_{t+1}) - \tilde{v}(x_t)\right) e_t \end{align*}\]
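A minimal sketch of one step of this rule, assuming a linear value function \(\tilde{v}(x) = w^\top x\) (so \(\nabla \tilde{v}(x_t) = x_t\)) and plain Python lists instead of tensors. `td_lambda_step` is a hypothetical helper, not a pandemonium function; it only transcribes the two update equations.

```python
from typing import List, Tuple

Vector = List[float]

def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))

def td_lambda_step(w: Vector, e: Vector, x_t: Vector, z_t: float, x_next: Vector,
                   gamma_t: float, gamma_next: float, lam: float, alpha: float
                   ) -> Tuple[Vector, Vector, float]:
    """One semi-gradient TD(lambda) step for a linear value v(x) = w . x."""
    # Trace: e_t = gamma_t * lam * e_{t-1} + grad v(x_t); for linear v the gradient is x_t
    e = [gamma_t * lam * ei + xi for ei, xi in zip(e, x_t)]
    # TD error with the current weights; the bootstrap term is treated as a constant
    delta = z_t + gamma_next * dot(w, x_next) - dot(w, x_t)
    # Weights: w_{t+1} = w_t + alpha * delta * e_t
    w = [wi + alpha * delta * ei for wi, ei in zip(w, e)]
    return w, e, delta

# Two-step episode: feature 0, then feature 1, then a terminal state (all-zero features)
w, e = [0.0, 0.0], [0.0, 0.0]
w, e, delta1 = td_lambda_step(w, e, [1.0, 0.0], 1.0, [0.0, 1.0],
                              gamma_t=0.0, gamma_next=0.9, lam=0.8, alpha=0.1)
w, e, delta2 = td_lambda_step(w, e, [0.0, 1.0], 1.0, [0.0, 0.0],
                              gamma_t=0.9, gamma_next=0.0, lam=0.8, alpha=0.1)
print([round(wi, 3) for wi in w])  # [0.172, 0.1]
```

On the second step feature 0 is inactive, yet its weight still moves by \(\alpha \delta \gamma \lambda = 0.072\) through the decayed trace; this backward credit assignment is what distinguishes \(\mathrm{TD}(\lambda)\) from one-step TD.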

References

Sutton, R. S. and Barto, A. G. (2018). “Reinforcement Learning: An Introduction”, 2nd ed., ch. 12.2. http://incompleteideas.net/book/the-book.html

Methods Summary

delta(self, t)

Specifies the update rule for the approximate value function (AVF).

target(self, t, v)

One-step target used in the TD error.

Methods Documentation

delta(self, t: pandemonium.experience.experience.Transition) → Tuple[Optional[torch.Tensor], dict]

Specifies the update rule for the approximate value function (AVF).

Because the algorithms in this family are online, the update rule is applied to every Transition as it arrives.
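To illustrate what "applied to every Transition" means in practice, here is a small sketch under stated assumptions: tabular (one-hot) features on a deterministic chain, and a hypothetical `Transition` tuple whose field names are made up for this example and are not pandemonium's `Transition` class. Weights change immediately after each transition, with no replay buffer, and setting \(\gamma_t = 0\) on the first step of an episode clears the trace through the \(\gamma_t \lambda e_{t-1}\) term.

```python
from typing import Dict, Iterator, NamedTuple

class Transition(NamedTuple):
    # Hypothetical stand-in for pandemonium's Transition; these field names are assumptions
    s: int             # current state (tabular, i.e. one-hot features)
    z: float           # cumulant received on this transition
    s_next: int        # next state
    gamma: float       # gamma_t; 0.0 on the first step of an episode
    gamma_next: float  # gamma_{t+1}; 0.0 when s_next is terminal

def chain_episode(n: int, discount: float) -> Iterator[Transition]:
    # Deterministic chain 0 -> 1 -> ... -> n (terminal); cumulant 1.0 on the last step only
    for s in range(n):
        last = s == n - 1
        yield Transition(s, 1.0 if last else 0.0, s + 1,
                         0.0 if s == 0 else discount, 0.0 if last else discount)

def run_online(episodes: int, n: int = 3, discount: float = 0.9,
               lam: float = 0.8, alpha: float = 0.1) -> Dict[int, float]:
    v = {s: 0.0 for s in range(n + 1)}  # tabular weights; v[n] (terminal) is never updated
    e = {s: 0.0 for s in range(n + 1)}
    for _ in range(episodes):
        for t in chain_episode(n, discount):
            # Decay by gamma_t * lam; gamma_t = 0 clears the trace at episode start
            for s in e:
                e[s] *= t.gamma * lam
            e[t.s] += 1.0  # gradient of a tabular value is 1 at the visited state
            delta = t.z + t.gamma_next * v[t.s_next] - v[t.s]
            # Online: weights change immediately, before the next transition is seen
            for s in v:
                v[s] += alpha * delta * e[s]
    return v

values = run_online(episodes=1000)
print({s: round(x, 3) for s, x in values.items()})  # {0: 0.81, 1: 0.9, 2: 1.0, 3: 0.0}
```

The estimates settle at the discounted returns \(0.9^2\), \(0.9\) and \(1\), which are exactly the values at which every TD error on this deterministic chain is zero.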

target(self, t: pandemonium.experience.experience.Transition, v: torch.Tensor)

One-step target used in the TD error.
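A hedged sketch of the quantity this method describes: the one-step target \(z_t + \gamma_{t+1} \tilde{v}(x_{t+1})\), which reduces to \(z_t\) when \(\gamma_{t+1} = 0\). It also scores the resulting error with a smooth L1 (Huber-style) loss, since `criterion` defaults to `smooth_l1_loss` (presumably PyTorch's `torch.nn.functional.smooth_l1_loss`, whose default `beta` is 1.0). How pandemonium actually interprets the `v` argument or combines `criterion` with the trace is not stated here; `one_step_target` and `smooth_l1` are hypothetical helpers.

```python
def one_step_target(z: float, gamma_next: float, v_next: float) -> float:
    # z_t + gamma_{t+1} * v(x_{t+1}); semi-gradient: v_next is treated as a constant
    return z + gamma_next * v_next

def smooth_l1(error: float, beta: float = 1.0) -> float:
    # Quadratic for |error| < beta, linear beyond; assumed to mirror
    # torch.nn.functional.smooth_l1_loss with its default beta=1.0
    a = abs(error)
    return 0.5 * a * a / beta if a < beta else a - 0.5 * beta

target = one_step_target(z=1.0, gamma_next=0.9, v_next=2.0)   # bootstrapped target
td_error = target - 0.5                                        # prediction v(x_t) = 0.5
terminal = one_step_target(z=1.0, gamma_next=0.0, v_next=2.0)  # no bootstrapping
print(round(target, 3), round(td_error, 3), round(smooth_l1(td_error), 3), terminal)
# 2.8 2.3 1.8 1.0
```

The linear branch of the loss keeps large TD errors (such as 2.3 here) from producing outsized gradients, while small errors are penalised quadratically as in a squared loss.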