# Load packages JM and lattice
library("JM")
library("lattice")


# composite-event indicator for the PBC data: 1 when status is anything
# other than "alive", 0 otherwise; added to both pbc2 and pbc2.id
pbc2[["status2"]] <- as.numeric(pbc2[["status"]] != "alive")
pbc2.id[["status2"]] <- as.numeric(pbc2.id[["status"]] != "alive")


#################
# Section 6.1.1 #
#################

# Longitudinal submodel: linear mixed model for CD4 with a time effect and a
# time-by-drug interaction, plus random intercepts and random slopes in
# obstime for each patient (long-format data `aids`).
lmeFit.aids <- lme(CD4 ~ obstime + obstime:drug,
    random = ~ obstime | patient, data = aids)

# Survival submodel: Cox model for death with drug as covariate, fitted on
# the per-patient data `aids.id`. x = TRUE keeps the design matrix in the
# fit object, which JM's jointModel() requires.
coxFit.aids <- coxph(Surv(Time, death) ~ drug, data = aids.id,
    x = TRUE)

# Joint model linking the two submodels; timeVar names the time variable of
# the longitudinal data. The "piecewise-PH-aGH" method (per the JM docs)
# uses a piecewise-constant baseline hazard with pseudo-adaptive
# Gauss-Hermite quadrature.
jointFit.aids <- jointModel(lmeFit.aids, coxFit.aids,
    timeVar = "obstime", method = "piecewise-PH-aGH")


# built-in diagnostic plots of the AIDS joint model, laid out on a 2 x 2 grid
par(mfrow = c(2, 2))
plot(jointFit.aids)

# marginal residuals and marginal fitted values of the longitudinal submodel
resMargY.aids <- residuals(jointFit.aids, type = "Marginal",
    process = "Longitudinal")
fitMargY.aids <- fitted(jointFit.aids, type = "Marginal",
    process = "Longitudinal")


# Scatterplot of y against x with a superimposed lowess smoother and a
# dotted grey reference line at zero (used for residual-vs-fitted plots).
#
# Arguments:
#   x, y       numeric vectors of equal length (e.g. fitted values and
#              residuals).
#   col.loess  colour of the smoother line.
#   ...        further graphical arguments passed on to plot().
#
# Pairs where x or y is NA/NaN/Inf are still handed to plot() (which drops
# them itself) but are excluded from the smoother, because lowess() stops
# with an error on non-finite input.
#
# Returns NULL invisibly; called for its side effect on the current device.
plotResid <- function (x, y, col.loess = "black", ...) {
    plot(x, y, ...)
    ok <- is.finite(x) & is.finite(y)
    lines(lowess(x[ok], y[ok]), col = col.loess, lwd = 2)
    abline(h = 0, lty = 3, col = "grey", lwd = 2)
    invisible(NULL)
}

# scatterplot of the marginal residuals against the marginal fitted values
plotResid(x = fitMargY.aids, y = resMargY.aids,
    ylab = "Marginal Residuals", xlab = "Fitted Values")


#################
# Section 6.1.2 #
#################

# martingale residuals of the survival submodel
martRes <- residuals(jointFit.aids, process = "Event")

# subject-specific fitted values of the longitudinal outcome evaluated at
# each patient's event time ("EventTime"); these are paired one-to-one
# with martRes below, so both hold one value per patient
mi.t <- fitted(jointFit.aids, process = "Longitudinal",
    type = "EventTime")

# scatterplot of martingale residuals vs subject-specific fitted values
plotResid(mi.t, martRes, col.loess = "grey62",
    ylab = "Martingale Residuals",
    xlab = "Subject-Specific Fitted Values Longitudinal Outcome")

# the same scatterplot per treatment group; because martRes and mi.t are
# per-patient quantities, the conditioning variable drug must come from the
# one-row-per-patient data aids.id (the long-format aids has one row per
# measurement and would not line up with them)
xyplot(martRes ~ mi.t | drug, data = aids.id, type = c("p", "smooth"),
    col = "black", lwd = 3, ylab = "Martingale Residuals",
    xlab = "Subject-Specific Fitted Values Longitudinal Outcome")


# Cox-Snell residuals of the survival submodel; they are used below as the
# "time" of a Surv object on aids.id, i.e. one residual per patient
resCST <- residuals(jointFit.aids, process = "Event",
    type = "CoxSnell")

# Kaplan-Meier estimate of the survival function of the
# Cox-Snell residuals
sfit <- survfit(Surv(resCST, death) ~ 1, data = aids.id)
plot(sfit, mark.time = FALSE, conf.int = TRUE,
    xlab = "Cox-Snell Residuals", ylab = "Survival Probability",
    main = "Survival Function of Cox-Snell Residuals")

# superimpose the survival function of the unit exponential distribution;
# the x-axis is on the residual scale, so the reference curve spans the
# range of the residuals rather than the range of the follow-up times
curve(exp(-x), from = 0, to = max(resCST, na.rm = TRUE), add = TRUE,
    col = "grey62", lwd = 2)


# Kaplan-Meier estimate of the survival function of the
# Cox-Snell residuals per treatment group
sfit <- survfit(Surv(resCST, death) ~ drug, data = aids.id)
plot(sfit, mark.time = FALSE, xlab = "Cox-Snell Residuals",
    ylab = "Survival Probability",
    main = "Survival Function of Cox-Snell Residuals")

# survival function of the unit exponential distribution, over the range
# of the residuals
curve(exp(-x), from = 0, to = max(resCST, na.rm = TRUE), add = TRUE,
    col = "grey62", lwd = 2)


###############
# Section 6.2 #
###############

# a joint model for the PBC dataset: log(serBilir) as the longitudinal
# outcome and the composite event indicator status2 as the survival outcome
# Longitudinal submodel: fixed effects year, drug and their interaction;
# random intercepts and slopes in year per id (long-format data pbc2).
lmeFit2.pbc <- lme(log(serBilir) ~ year * drug,
    random = ~ year | id, data = pbc2)

# Survival submodel on the per-subject data pbc2.id; x = TRUE keeps the
# design matrix in the fit, as JM's jointModel() requires.
coxFit.pbc <- coxph(Surv(years, status2) ~ drug + hepatomegaly,
    data = pbc2.id, x = TRUE)

# Joint model with the same estimation method as the AIDS model.
jointFit2.pbc <- jointModel(lmeFit2.pbc, coxFit.pbc, timeVar = "year",
    method = "piecewise-PH-aGH")

# subject-specific fitted values and standardized subject-specific residuals
fitSubY.pbc <- fitted(jointFit2.pbc, type = "Subject",
    process = "Longitudinal")
resSubY.pbc <- residuals(jointFit2.pbc, type = "stand-Subject",
    process = "Longitudinal")

# marginal fitted values and standardized marginal residuals
fitMargY.pbc <- fitted(jointFit2.pbc, type = "Marginal",
    process = "Longitudinal")
resMargY.pbc <- residuals(jointFit2.pbc, type = "stand-Marginal",
    process = "Longitudinal")

# side-by-side residual-vs-fitted plots: subject-specific (left) and
# marginal (right)
par(mfrow = c(1, 2))
plotResid(x = fitSubY.pbc, y = resSubY.pbc,
    ylab = "Subject-Specific Residuals", xlab = "Fitted Values")
plotResid(x = fitMargY.pbc, y = resMargY.pbc,
    ylab = "Marginal Residuals", xlab = "Fitted Values")


#################
# Section 6.3.1 #
#################

# multiple-imputation residuals (see JM's ?residuals.jointModel for how the
# missing cases are imputed)
set.seed(123) # we set the seed for reproducibility
resMI.aids <- residuals(jointFit.aids, process = "Longitudinal",
    type = "Marginal", MI = TRUE)

# extract the residuals and fitted values corresponding to
# missing cases
fitMargYmiss.aids <- resMI.aids$fitted.valsM
resMargYmiss.aids <- resMI.aids$resid.valsM


# scatterplot of observed + multiply imputed residuals and fitted values.
# resid.valsM has one column per imputation; c() stacks those columns one
# after another, so the fitted values of the missing cases are repeated M
# times to line up with them
M <- ncol(resMargYmiss.aids) # number of imputations
resMargYmi.aids <- c(resMargY.aids, resMargYmiss.aids)
fitMargYmi.aids <- c(fitMargY.aids, rep(fitMargYmiss.aids, M))
# NOTE(review): residuals here are of type "Marginal", yet the axis label
# says "Standardized" -- confirm which one is intended
plot(range(fitMargYmi.aids), range(resMargYmi.aids), type = "n",
    xlab = "Fitted Values",
    ylab = "MI Standardized Marginal Residuals")
abline(h = 0, lty = 2)
points(rep(fitMargYmiss.aids, M), resMargYmiss.aids, cex = 0.5,
    col = "grey")
points(fitMargY.aids, resMargY.aids)

# lowess smoother based on observed data alone
lines(lowess(fitMargY.aids, resMargY.aids), lwd = 2)

# loess smoother based on observed + multiply imputed data; each imputed
# value gets weight 1/M so that the M imputations of one missing case
# together weigh as much as one observed measurement
datResid <- data.frame(
    resid = resMargYmi.aids,
    fitted = fitMargYmi.aids,
    weight = c(rep(1, length(resMargY.aids)),
        rep(1/M, length(resMargYmiss.aids)))
)
fitLoess.aids <- loess(resid ~ fitted, data = datResid,
    weights = weight)
nd.aids <- data.frame(fitted = seq(min(fitMargYmi.aids),
    max(fitMargYmi.aids), length.out = 100))
prdLoess.aids <- predict(fitLoess.aids, nd.aids)
# use the exact column name: `$fit` only worked through partial matching
lines(nd.aids$fitted, prdLoess.aids, lwd = 2, lty = 2)


#################
# Section 6.3.2 #
#################

# Dataset for the visiting-process model. Per patient:
#   diff.time  gaps between consecutive visit times (diff of year)
#   prev.y     log(serBilir) at the start of each gap (all but the last value)
# NOTE(review): assumes rows are ordered by year within each id -- confirm.
# Patients with a single visit have no gaps; they get one row of NAs so
# that every patient still appears in dataVT.
diff.time <- tapply(pbc2$year, pbc2$id, diff)
prev.y <- tapply(log(pbc2$serBilir), pbc2$id, head, -1)
one.visit <- vapply(diff.time, length, integer(1)) == 0
prev.y[one.visit] <- NA
diff.time[one.visit] <- NA

# every gap is an observed "event" (the next visit)
dataVT <- data.frame(
    id = rep(names(prev.y), vapply(prev.y, length, integer(1))),
    diff.Times = unlist(diff.time),
    prev.y = unlist(prev.y),
    event = 1
)

# fit of the Weibull model with a Gamma frailty for
# the visiting process: the gap times between visits are modelled with the
# previous log(serBilir) value as covariate, and gaps from the same patient
# share a frailty term through the id column of dataVT
WeibFrl <- weibull.frailty(Surv(diff.Times, event) ~ prev.y,
    id = "id", data = dataVT)

summary(WeibFrl)


# calculate multiply imputed residuals using the visiting
# process model (WeibFrl is passed as time.points)
set.seed(123) # we set the seed for reproducibility
resMI.pbc <- residuals(jointFit2.pbc, process = "Longitudinal",
    type = "stand-Marginal", MI = TRUE, M = 10, time.points = WeibFrl)

# extract the residuals and fitted values corresponding to
# missing cases (one list element per imputation)
fitMargYmiss.pbc <- unlist(resMI.pbc$fitted.valsM)
resMargYmiss.pbc <- unlist(resMI.pbc$resid.valsM)

# dataset containing observed + multiply imputed residuals and fitted
# values; imputed values get weight 1/M so that the M imputations of one
# missing case together weigh as much as one observed measurement
M <- length(resMI.pbc$fitted.valsM)
datResid <- data.frame(
    resid = c(resMargY.pbc, resMargYmiss.pbc),
    fitted = c(fitMargY.pbc, fitMargYmiss.pbc),
    weight = c(rep(1, length(resMargY.pbc)),
        rep(1/M, length(resMargYmiss.pbc))))

datResid <- datResid[complete.cases(datResid), ]

# loess smoother based on observed + multiply imputed data
fitLoess.pbc <- loess(resid ~ fitted, data = datResid,
    weights = weight)
nd.pbc <- data.frame(fitted = seq(min(datResid$fitted, na.rm = TRUE),
    max(datResid$fitted, na.rm = TRUE), length.out = 100))

prdLoess.pbc <- predict(fitLoess.pbc, nd.pbc)

# scatterplot of observed + multiply imputed residuals and fitted values;
# the plotting region covers both observed and imputed values so that no
# observed point falls outside it
plot(range(c(fitMargY.pbc, fitMargYmiss.pbc), na.rm = TRUE),
    range(c(resMargY.pbc, resMargYmiss.pbc), na.rm = TRUE),
    type = "n", xlab = "Fitted Values",
    ylab = "Standardized Marginal Residuals")
abline(h = 0, lty = 2)
points(fitMargY.pbc, resMargY.pbc)
points(fitMargYmiss.pbc, resMargYmiss.pbc, col = "grey")
# lowess smoother on observed data alone (solid) vs loess smoother on
# observed + imputed data (dashed)
lines(lowess(fitMargY.pbc, resMargY.pbc), lwd = 2)
lines(nd.pbc$fitted, prdLoess.pbc, lwd = 2, lty = 2)