import numpy as np
import matplotlib.pyplot as plt

import petpal.kinetic_modeling.tcms_as_convolutions as pet_tcm
import petpal.kinetic_modeling.reference_tissue_models as pet_rtms
import petpal.kinetic_modeling.fit_tac_with_rtms as fit_rtms
from petpal.visualizations.tac_plots import TacFigure as TACPlots
from petpal.utils.time_activity_curve import TimeActivityCurve

# Load the plasma input TAC (PTAC). The path is relative to the current
# working directory. NOTE(review): this assumes the script is launched from
# its own directory — confirm.
input_tac = TimeActivityCurve.from_tsv("../../../../../data/tcm_tacs/fdg_plasma_clamp_evenly_resampled.txt")

# Resample the input onto 8192 evenly spaced samples before feeding it to the
# TCM generator (presumably the convolution-based model needs a uniform grid).
input_tac_resampled = input_tac.evenly_resampled_tac(8192)

# Simulate the reference-region TAC from the resampled input using the
# one-tissue compartment generator (per its name) with K1 = 0.25, k2 = 0.2.
tac_times_in_minutes, ref_tac_vals = pet_tcm.gen_tac_1tcm_cpet_from_tac(
    tac_times=input_tac_resampled.times,
    tac_vals=input_tac_resampled.activity,
    k1=0.25,
    k2=0.2,
)
reference_tac = TimeActivityCurve(tac_times_in_minutes, ref_tac_vals)

# Ground-truth SRTM parameters. Insertion order (r1, k2, bp) matters: the
# fit check later compares them positionally against the fitted vector.
test_params = {'r1': 1.0, 'k2': 0.25, 'bp': 3.0}

# Simulate the target-region TAC with the SRTM, driven by the reference TAC.
srtm_tac_vals = pet_rtms.calc_srtm_tac(
    tac_times_in_minutes=tac_times_in_minutes,
    ref_tac_vals=ref_tac_vals,
    **test_params,
)
srtm_tac = TimeActivityCurve(tac_times_in_minutes, srtm_tac_vals)


# Fit the SRTM to the simulated target TAC, using the simulated reference TAC.
rtm_analysis = fit_rtms.FitTACWithRTMs(
    target_tac=srtm_tac,
    reference_tac=reference_tac,
    method='srtm',
)

# Run the fit and pull out the fitted parameter vector. Element [0] is taken;
# NOTE(review): fit_results presumably is a (params, covariance) pair — confirm.
rtm_analysis.fit_tac_to_model()
fit_results = rtm_analysis.fit_results[0]
fit_results_dict = {
    param_name: fit_results[position]
    for position, param_name in enumerate(('r1', 'k2', 'bp'))
}

# Sanity check: the fit must recover the ground-truth parameters (positional
# comparison relies on test_params being ordered r1, k2, bp).
assert np.allclose(fit_results, [*test_params.values()])

# Regenerate the target TAC from the fitted parameters for plotting.
fit_srtm_tac_vals = pet_rtms.calc_srtm_tac(
    tac_times_in_minutes=tac_times_in_minutes,
    ref_tac_vals=ref_tac_vals,
    **fit_results_dict,
)


# Plot the results. The input and reference TACs are drawn as dashed lines.
# The simulated SRTM TAC (black 'x') is overlaid with the TAC regenerated from
# the fitted parameters (red 'o'). Both are subsampled every `marker_stride`
# points so the markers remain distinguishable.
marker_stride = 50
tac_plt = TACPlots(ylabel=r'TAC $(\mathrm{nCi/ml})$')
tac_plt.add_tac(*input_tac.tac, label='PTAC', alpha=0.6, ls='--')
tac_plt.add_tac(tac_times_in_minutes, ref_tac_vals, label='Ref TAC', alpha=0.6, ls='--')
tac_plt.add_tac(tac_times_in_minutes[::marker_stride], srtm_tac_vals[::marker_stride],
                label='SRTM TAC', marker='x', color='black', ms=10)
# The "fit" series must show the TAC regenerated from the fitted parameters,
# not the ground truth, or the overlay says nothing about fit quality.
tac_plt.add_tac(tac_times_in_minutes[::marker_stride], fit_srtm_tac_vals[::marker_stride],
                label='SRTM TAC Fit', marker='o', color='red', ms=5)
plt.legend()
plt.ylim(0, None)
plt.show()