This week I refined the tests for my code and completed the doctests.
My project uses four different dependencies, none of which had previously been used by Stingray, so I had to make some changes to the code structure to accommodate these exceptions.
I also updated and improved a lot of docstrings and added some doctests, which was a new experience. My get_prior doctest looked like this:
>>> if not can_sample:
...     pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.")
>>> if not tfp_available:
...     pytest.skip("Tensorflow probability required to make priors.")
>>> params_list = get_gp_params("RN", "gaussian")
>>> # Make a prior dictionary using tensorflow_probability distributions
>>> prior_dict = {
...     "A": tfpd.Uniform(low=1e-1, high=2e+2),
...     "t0": tfpd.Uniform(low=0.0 - 0.1, high=1 + 0.1),
...     "sig": tfpd.Uniform(low=0.5 * 1 / 20, high=2),
...     "arn": tfpd.Uniform(low=0.1, high=2),
...     "crn": tfpd.Uniform(low=jnp.log(1 / 5), high=jnp.log(20)),
... }
>>> prior_model = get_prior(params_list, prior_dict)
This week I also made many futile attempts to resolve an error in JAX. As good and fast as the library is (I have yet to use some of its mind-blowing features, like pytrees), it has one limitation (or design decision): arrays whose size is not fixed cannot be used inside a jit-compiled function.
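To make the limitation concrete, here is a minimal sketch (not code from my project; the function names are made up for illustration). Boolean-mask indexing produces an array whose length depends on the data, which works in ordinary eager mode but fails once the function is wrapped in jax.jit. The exact exception class may differ between JAX versions, so the sketch only records that an error was raised.

```python
import jax
import jax.numpy as jnp


def positive_values(x):
    # The output length depends on the values in x, not just on its shape.
    return x[x > 0]


# Hypothetical jit-compiled version of the same function.
positive_values_jit = jax.jit(positive_values)

x = jnp.array([-1.0, 2.0, 3.0])

# Outside jit the values are concrete, so the result has a known length.
eager_result = positive_values(x)

# Inside jit the values are abstract tracers, so the output shape
# cannot be determined while the function is being traced.
try:
    positive_values_jit(x)
    jit_failed = False
    error_name = None
except Exception as err:
    jit_failed = True
    error_name = type(err).__name__
```

Printing error_name shows which error your JAX version raises for this case.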
In my project, one important issue that we wanted to tackle was the non-stationarity of pulsar time series. The method to take this into account was to use a window over the data, i.e. only inside the window do we assume a QPO and compute its log likelihood, and outside the window we assume white noise. The problem was that a jit-compiled JAX function needs to know the type and size of all its arrays and data structures beforehand, hence there was no way to create a window over the time series.
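For simple element-wise quantities, a commonly used way around this is to keep every array at its full, fixed size and apply a boolean mask instead of slicing. The sketch below illustrates that idea on a plain sum; windowed_sum and its arguments are hypothetical names, this is not what my project does, and it may not carry over to a GP likelihood, where the points inside the window are correlated and cannot simply be masked one by one.

```python
import jax
import jax.numpy as jnp


@jax.jit
def windowed_sum(times, values, t_start, t_stop):
    # Boolean mask with the same fixed shape as `times`.
    inside = (times >= t_start) & (times <= t_stop)
    # Keep the shape fixed: zero out the points outside the window
    # instead of slicing them away.
    return jnp.sum(jnp.where(inside, values, 0.0))


times = jnp.linspace(0.0, 1.0, 11)
values = jnp.ones_like(times)

# The window edges are ordinary (traced) arguments, so they can change
# from call to call without changing any array shape.
narrow_total = windowed_sum(times, values, 0.25, 0.75)
wide_total = windowed_sum(times, values, -1.0, 2.0)
```

Because only the mask changes with the window edges, the function is traced once and reused for every window.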
While frustrating, this also shows that code is often used in ways very different from what was intended, or that there are issues that were not taken into account when these important design decisions were made. (This is no fault of the JAX library, though, as knowing the array shapes beforehand is crucial for fast parallel code.)