`writer.write_hdf5` fails in `model.predict(pbest, obs=obs, sps=sps)`
Hi,
I successfully ran my emcee sampler; however, I am having trouble saving the output. The code fails in `model.predict`, which `writer.write_hdf5` calls internally when it computes the best-fit model.
If you could help me resolve this, it would be greatly appreciated.
Here are the relevant code and the resulting error:
Relevant part of the code ` start_time = time.time()
AGN_file = fits.open(AGN_FITS_PATH)
AGN_data = AGN_file[1].data
gal_desig = AGN_data[galaxy_num][1]
# 4: ### Choose a galaxy (0 to 57) ### -------------------------------------------------------------------
galaxy_num = galaxy_num
Template_Type = Template_Type
# Create galaxy file to store plots and hdf5 data file
if not os.path.exists('{0}/G{1}/'.format(G_PATH, galaxy_num)): #'Galaxy_output/G{}/'
os.mkdir('{0}/G{1}/'.format(G_PATH, galaxy_num))
Galaxy_Path = '{0}/G{1}/'.format(G_PATH, galaxy_num)
print('{0}: This is for Galaxy {1} \t\t {2}'.format(time.strftime("%H:%M:%S", time.localtime()), galaxy_num, gal_desig))
ts = time.strftime("%y%b%d", time.localtime())
print('{0}: The Date is {1}'.format(time.strftime("%H:%M:%S", time.localtime()), ts))
print('{0}: The template type is {1}'.format(time.strftime("%H:%M:%S", time.localtime()), Template_Type))
theta_prediction, run_params, result = initialize_theta(galaxy_num, Template_Type, ts, AGN_data, gal_desig, Galaxy_Path, Run_Num, Num_Iters) #, input_hfile
print("{0} theta_prediction = {1}".format(time.strftime("%H:%M:%S", time.localtime()), theta_prediction))
# Re-run build functions with AGN
print('build obs')
obs = build_obs(**run_params)
print('build sps')
sps = build_sps(**run_params)
print('build model')
model = build_model(**run_params)
# 17: ### Run MCMC with emcee ### ------------------------------------------------------------------------
lnprobfn_fixed = lnprobfn #partial(prospect.fitting.lnprobfn, sps=sps, model=model, obs=obs)
# --- Run MCMC --- #
print('{1}: Start emcee for {0}'.format(run_params['ID'], time.strftime("%H:%M:%S", time.localtime())))
print('\tniter:', run_params['niter'])
print('\tnwalkers:', run_params['nwalkers'])
output = fit_model(obs, model, sps, lnprobfn=lnprobfn_fixed, **run_params) #OG
print('{2}: Finished emcee in {0:.2f}m for {1}'.format(output["sampling"][1]/60, run_params['ID'], time.strftime("%H:%M:%S", time.localtime())))
# 18: ### Create file path and re-run build functions ### ------------------------------------------------
hfile = Galaxy_Path + 'G{0}_{1}_{2}_res.h5'.format(galaxy_num, gal_desig, Run_Num)
obs, model, sps = build_obs(**run_params), build_model(**run_params), build_sps(**run_params)
# 19: ### Save results to h5 File ### -------------------------------------------------------------------
writer.write_hdf5(hfile, run_params, model, obs,
output["sampling"][0], output["optimization"][0],
tsample=output["sampling"][1],
toptimize=output["optimization"][1],
sps=sps)
print('{0}: Finished writing {1} file'.format(time.strftime("%H:%M:%S", time.localtime()), hfile))
# 20: ### Print time it takes to run ### -------------------------------------------------------------------
end_time = time.time()
print('{0}: This program takes:\n\t {1:.2f} \tsecs\n\t {2:.2f} \tmins\n\t {3:.2f} \thours'.format(time.strftime("%H:%M:%S", time.localtime()), (end_time - start_time), (end_time - start_time)/60, (end_time - start_time)/60/60))
return output`
Error Message:
`---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_4458/1401446808.py in
/tmp/ipykernel_4458/2085521713.py in PSB_AGN_CAPS_Funct(galaxy_num, Run_Num, Template_Type, Num_Iters) 283 284 # 19: ### Save results to h5 File ### ------------------------------------------------------------------- --> 285 writer.write_hdf5(hfile, run_params, model, obs, 286 output["sampling"][0], output["optimization"][0], 287 tsample=output["sampling"][1],
/mnt/c/Users/emma_d/ASTR_Research/lib/python3.8/site-packages/repo/prospector/prospect/io/write_results.py in write_hdf5(hfile, run_params, model, obs, sampler, optimize_result_list, tsample, toptimize, sampling_initial_center, sps, **extras) 154 from ..utils.plotting import get_best 155 _, pbest = get_best(hf["sampling"]) --> 156 spec, phot, mfrac = model.predict(pbest, obs=obs, sps=sps) 157 best = hf.create_group("bestfit") 158 best.create_dataset("spectrum", data=spec)
/mnt/c/Users/emma_d/ASTR_Research/lib/python3.8/site-packages/repo/prospector/prospect/models/sedmodel.py in predict(self, theta, obs, sps, **extras) 579 stellar mass formed. 580 """ --> 581 s, p, x = self.sed(theta, obs, sps=sps, **extras) 582 self._speccal = self.spec_calibration(obs=obs, **extras) 583 if obs.get('logify_spectrum', False):
/mnt/c/Users/emma_d/ASTR_Research/lib/python3.8/site-packages/repo/prospector/prospect/models/sedmodel.py in sed(self, theta, obs, sps, **kwargs) 621 """ 622 self.set_parameters(theta) --> 623 spec, phot, extras = sps.get_spectrum(outwave=obs['wavelength'], 624 filters=obs['filters'], 625 component=obs.get('component', -1),
/mnt/c/Users/emma_d/ASTR_Research/lib/python3.8/site-packages/repo/prospector/prospect/sources/ssp_basis.py in get_spectrum(self, outwave, filters, peraa, **params) 204 """ 205 # Spectrum in Lsun/Hz per solar mass formed, restframe --> 206 wave, spectrum, mfrac = self.get_galaxy_spectrum(**params) 207 208 # Redshifting + Wavelength solution
/mnt/c/Users/emma_d/ASTR_Research/lib/python3.8/site-packages/repo/prospector/prospect/sources/ssp_basis.py in get_galaxy_spectrum(self, **params)
329 """Construct the tabular SFH and feed it to the ssp.
330 """
--> 331 self.update(**params)
332 mtot = self.params['mass'].sum()
333 time, sfr, tmax = self.convert_sfh(self.params['agebins'], self.params['mass'])
/mnt/c/Users/emma_d/ASTR_Research/lib/python3.8/site-packages/repo/prospector/prospect/sources/ssp_basis.py in update(self, **params) 107 # copy of it in. 108 if k in self.ssp.params.all_params: --> 109 self.ssp.params[k] = deepcopy(v) 110 111 # We use FSPS for SSPs !!ONLY!!
~/miniconda3/envs/fsps-test/lib/python3.9/site-packages/fsps/fsps.py in setitem(self, k, v) 1302 is_changed = original != v 1303 -> 1304 if is_changed: 1305 if k in self.ssp_params: 1306 self.dirtiness = 2
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()`
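This is probably another issue with `get_best`. The `ValueError` is raised inside FSPS's parameter setter. There, `is_changed = original != v` evaluates to an array, so some parameter value being copied into `self.ssp.params` is a multi-element array rather than a scalar. I suspect that the line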
_, pbest = get_best(hf["sampling"])
is returning something non-standard for `pbest`. I don't have time to debug this at the moment,
but hopefully I can dig into it in the next week or so.