# Load packages JM and lattice
library("JM")
library("lattice")

# Composite-event indicator for the PBC data: 1 whenever the status is
# anything other than "alive", 0 otherwise. Added to both data frames.
pbc2[["status2"]] <- as.numeric(pbc2[["status"]] != "alive")
pbc2.id[["status2"]] <- as.numeric(pbc2.id[["status"]] != "alive")

# Prothrombin-time indicator: values inside [10, 13] are labelled "Normal",
# everything else "Abnormal" (the FALSE/TRUE levels receive the labels in
# that order).
pbc2.id$Pro <- factor(
    pbc2.id$prothrombin >= 10 & pbc2.id$prothrombin <= 13,
    labels = c("Abnormal", "Normal"))
# Expand the indicator to pbc2 by repeating each entry of pbc2.id$Pro as many
# times as its id occurs in pbc2.
pbc2$Pro <- rep(pbc2.id$Pro, times = tapply(pbc2$id, pbc2$id, length))


#################
# Section 7.1.3 #
#################

# Linear mixed model for log serum bilirubin. Both the fixed-effects part and
# the per-id random-effects part use a B-spline basis of `year` (df = 4,
# boundary knots at 0 and 15); pdDiag gives the random effects a diagonal
# covariance matrix.
lmeFitBsp.pbc <- lme(
    fixed = log(serBilir) ~ bs(year, 4, Boundary.knots = c(0, 15)),
    random = list(
        id = pdDiag(form = ~ bs(year, 4, Boundary.knots = c(0, 15)))),
    data = pbc2)

# Cox model for the composite event (status2) with `drug` and `Pro` as
# covariates, fitted on pbc2.id; x = TRUE stores the design matrix in the fit.
coxFit.pbc <- coxph(Surv(years, status2) ~ drug + Pro,
    data = pbc2.id, x = TRUE)

# Joint model built from the two fits above. timeVar names the time variable
# of the longitudinal data; the method string selects jointModel()'s
# "piecewise-PH-aGH" option.
jointFitBsp.pbc <- jointModel(lmeFitBsp.pbc, coxFit.pbc,
    timeVar = "year", method = "piecewise-PH-aGH")

# Conditional survival probabilities for Patient 2, Monte Carlo version.
# The seed is fixed so that the simulated results are reproducible.
set.seed(123)
survPrbs <- survfitJM(
    jointFitBsp.pbc,
    newdata = pbc2[pbc2$id %in% 2, ])
survPrbs

# Same patient with simulate = FALSE (the empirical Bayes variant, as
# opposed to the Monte Carlo scheme above).
survPrbsEB <- survfitJM(
    jointFitBsp.pbc,
    newdata = pbc2[pbc2$id %in% 2, ],
    simulate = FALSE)
survPrbsEB

# Same patient, additionally passing the "years" column as last.time, i.e.
# the last time point at which the patient was known to be alive.
set.seed(123)
survPrbs2 <- survfitJM(
    jointFitBsp.pbc,
    newdata = pbc2[pbc2$id %in% 2, ],
    last.time = "years")
survPrbs2

# Conditional survival probabilities at the specific times 14.5 and 15;
# the result is only printed, not stored.
set.seed(123)
survfitJM(
    jointFitBsp.pbc,
    newdata = pbc2[pbc2$id %in% 2, ],
    last.time = "years",
    survTimes = c(14.5, 15))


# Plot the Monte Carlo estimates of the conditional survival probabilities,
# including the confidence intervals.
plot(survPrbs, conf.int = TRUE, lty = c(1, 2, 3, 3))

# Dynamic survival probabilities for Patient 2: element i of survPreds is
# based on the first i longitudinal rows of the patient only. The seed is
# reset before every call so each prediction is reproducible on its own.
# seq_len() is used instead of 1:nrow(ND) so that an empty ND yields zero
# iterations rather than the sequence c(1, 0).
ND <- pbc2[pbc2$id == 2, ]
survPreds <- vector("list", nrow(ND))
for (i in seq_len(nrow(ND))) {
    set.seed(123)
    survPreds[[i]] <- survfitJM(jointFitBsp.pbc, newdata = ND[seq_len(i), ])
}

# Plot of the dynamic survival probabilities after the 1st, 3rd, 5th and 7th
# measurement of the patient, arranged in a 2 x 2 grid with shared outer
# axis labels.
# The previous graphical parameters are saved and restored afterwards so the
# 2 x 2 layout and the outer margins do not leak into later base-graphics
# plots in this script.
op <- par(mfrow = c(2, 2), oma = c(0, 2, 0, 2))
for (i in c(1, 3, 5, 7)) {
    plot(survPreds[[i]], estimator = "median", conf.int = TRUE,
        include.y = TRUE, main = paste("Follow-up time:",
            round(survPreds[[i]]$last.time, 1)))
}
mtext("log serum bilirubin", side = 2, line = -1, outer = TRUE)
mtext("Survival Probability", side = 4, line = -1, outer = TRUE)
par(op)


# Predictions behind Figure 7.4: for Patients 2 and 25, and for each
# follow-up time in `times`, use only the measurements taken up to that time
# and predict survival Dt = 1, 2 and 4 years further ahead.
DD <- subset(pbc2, id %in% c("2", "25"))
Dt <- c(1, 2, 4)
times <- c(0, 0.6, 2.2, 5.1)
predSurv <- vector("list", length(times))
for (i in seq_along(predSurv)) {
    # same seed for every follow-up time, for reproducibility
    set.seed(123)
    dd <- DD[DD$year <= times[i], , drop = FALSE]
    predSurv[[i]] <- survfitJM(jointFitBsp.pbc, newdata = dd, M = 1000,
        survTimes = Dt + times[i])
}

# extract results
# Extract column `v` of the summary table stored under name `p` in
# x$summaries.
f <- function (x, p, v) {
    patient_summary <- x$summaries[[p]]
    patient_summary[, v]
}
# Long-format table for the dot plot: 2 patients x 4 follow-up times x 3
# prediction horizons = 24 rows. For each patient the sapply() calls return
# one column per follow-up time, flattened column by column, which matches
# the `time` (each = 3) and `Dt` (cycling 1, 2, 3) layouts below.
tab <- data.frame(
    time = factor(rep(rep(c(0, 0.5, 2, 5), each = 3), times = 2)),
    surv = c(
        sapply(predSurv, f, p = "2", v = "Median"),
        sapply(predSurv, f, p = "25", v = "Median")),
    low = c(
        sapply(predSurv, f, p = "2", v = "Lower"),
        sapply(predSurv, f, p = "25", v = "Lower")),
    up = c(
        sapply(predSurv, f, p = "2", v = "Upper"),
        sapply(predSurv, f, p = "25", v = "Upper")),
    patient = gl(2, 12, labels = paste("Patient", c(2, 25))),
    # horizon labels, with the singular "year" for the one-year horizon
    Dt = gl(3, 1, 24,
        labels = c("Extra 1 year", "Extra 2 years", "Extra 4 years"))
)

# Prepanel and panel functions to construct error bars.
#
# prepanel.ci: returns the x-axis limits needed by one panel so that the
# point estimates (x) and both interval ends fit. lx and ux are full-length
# vectors of lower/upper limits; `subscripts` picks the entries belonging to
# the panel. Non-finite values are ignored. y and ... are accepted but unused.
prepanel.ci <- function (x, y, lx, ux, subscripts, ...) {
    lower <- as.numeric(lx[subscripts])
    upper <- as.numeric(ux[subscripts])
    list(xlim = range(as.numeric(x), upper, lower, finite = TRUE))
}
# panel.ci: draws one panel of the dot plot. It adds dashed grey reference
# lines (horizontal at each distinct y, vertical every 0.05 between 0 and 1),
# then horizontal error bars from the lower to the upper limit, and finally
# the point estimates. lx and ux are full-length vectors indexed by
# `subscripts`; `pch` and ... are forwarded to panel.xyplot().
panel.ci <- function (x, y, lx, ux, subscripts, pch = 16, ...) {
    est <- as.numeric(x)
    pos <- as.numeric(y)
    lower <- as.numeric(lx[subscripts])
    upper <- as.numeric(ux[subscripts])
    panel.abline(
        h = unique(pos),
        v = seq(0, 1, by = 0.05),
        col = "grey", lty = 2, lwd = 1.5)
    # code = 3 with angle = 90 draws flat caps at both ends of each bar
    panel.arrows(lower, pos, upper, pos,
        col = 1, lwd = 2,
        angle = 90, code = 3,
        length = 0.1, unit = "native")
    panel.xyplot(est, pos, pch = pch, col = 1, cex = 1, ...)
}
# Dot plot of the predicted survival probabilities with error bars, one panel
# per prediction horizon and patient. `low` and `up` are looked up in `tab`
# via with() and reach the (pre)panel functions as lx / ux.
# as.table = TRUE is an argument of the lattice high-level call (it orders
# the panels top-to-bottom); it was previously handed to print(), where it
# has no effect, so it is passed to dotplot() here.
with(tab, print(dotplot(time ~ surv | Dt*patient, lx = low, ux = up,
    xlab = "Survival Probability", ylab = "Follow-up Time (years)",
    prepanel = prepanel.ci, panel = panel.ci, as.table = TRUE)))


###############
# Section 7.2 #
###############

# Dynamic prediction for serum bilirubin of Patient 2: element i of longPreds
# holds the subject-specific predictions (with confidence intervals, returned
# together with the data because returnData = TRUE) based on the first i
# measurements. FollowUp records the latest measurement time used, rounded
# to one decimal.
# seq_len() is used instead of 1:nrow(ND) so that an empty ND yields zero
# iterations rather than the sequence c(1, 0).
ND <- pbc2[pbc2$id == 2, ]
longPreds <- vector("list", nrow(ND))
for (i in seq_len(nrow(ND))) {
    set.seed(123) # we set the seed for reproducibility
    longPreds[[i]] <- predict(jointFitBsp.pbc, newdata = ND[seq_len(i), ],
        type = "Subject", interval = "confidence", returnData = TRUE)
    longPreds[[i]]$FollowUp <- round(max(ND[seq_len(i), "year"]), 1)
}

# Put all results in the same data frame and turn FollowUp into a factor
# with descriptive labels.
# factor() sorts its levels, and labels are matched to the levels by
# position, so the label vector is built from the sorted unique values.
# With unique() alone the labels follow the order of appearance and could
# be attached to the wrong level if that order were not increasing.
longPreds.all <- do.call(rbind, longPreds)
longPreds.all$FollowUp <- with(longPreds.all, factor(FollowUp,
    labels = paste("Follow-up time:", sort(unique(FollowUp)))))

# plot of dynamic predictions
# One panel per follow-up time. The three responses (pred, low, upp) are
# given in one formula, and the panel function relies on receiving them
# stacked: x holds the time values three times over and y holds pred, low and
# upp one after the other.
xyplot(pred + low + upp ~ year | FollowUp, data = longPreds.all,
    panel = function (x, y) {
        # time values of a single response (first third of x)
        xx <- x[seq_len(length(x) / 3)]
        # columns: 1 = pred, 2 = low, 3 = upp
        yy <- matrix(y, ncol = 3)
        # largest time at which some y value is missing
        # NOTE(review): presumably the rows of the observed measurements
        # carry NA limits, making this the last observation time -- confirm
        last.time <- max(x[is.na(y)])
        # points at or after last.time; the first of them is switched off so
        # that it is drawn with the solid segment instead of the dashed one
        ind <- xx >= last.time
        ind[which(ind)[1]] <- FALSE
        # grey band between low and upp over the flagged points
        lpolygon(c(xx[ind], rev(xx[ind])),
            c(yy[ind, 2], rev(yy[ind, 3])), border = "transparent", col = "grey")
        # dashed line for the flagged points, solid line for the rest
        panel.xyplot(xx[ind], yy[ind, 1], lty = 2, lwd = 2,
            type = "l", col = 1)
        panel.xyplot(xx[!ind], yy[!ind, 1], lty = 1, lwd = 2,
            type = "l", col = 1)
        # dotted vertical line at last.time
        panel.abline(v = last.time, lty = 3, lwd = 2)
    }, as.table = TRUE, xlab = "Time",
    ylab = "Predicted log serum bilirubin", layout = c(3,3))


###############
# Section 7.3 #
###############

# Longitudinal and survival submodels for this section. These two calls are
# the same specifications as used for Section 7.1.3 earlier in this file, so
# the section can be run starting from here.
# Mixed model: log serum bilirubin on a B-spline basis of `year` (df = 4,
# boundary knots 0 and 15) in both the fixed and the diagonal (pdDiag)
# random-effects part.
lmeFitBsp.pbc <- lme(
    fixed = log(serBilir) ~ bs(year, 4, Boundary.knots = c(0, 15)),
    random = list(
        id = pdDiag(form = ~ bs(year, 4, Boundary.knots = c(0, 15)))),
    data = pbc2)

# Cox model for the composite event with `drug` and `Pro`; x = TRUE stores
# the design matrix in the fit.
coxFit.pbc <- coxph(Surv(years, status2) ~ drug + Pro,
    data = pbc2.id, x = TRUE)

# list to calculate the derivative of the B-spline
# Passed as `derivForm` to the "slope" and "both" parameterizations below.
# The `fixed` and `random` formulas are identical: no intercept, and two
# terms built from a degree-2 B-spline basis of `year` (one interior knot at
# 2.0534443, boundary knots 0 and 15), once multiplied by 3 / 15 and once by
# -3 / 15.
# NOTE(review): the interior knot is hard-coded; presumably it is the knot
# that bs(year, 4) places for pbc2$year in the mixed model -- confirm if the
# data or the spline specification change.
dform <- list(
    fixed = ~ -1
    + I(3 * bs(year, knots = 2.0534443,
        Boundary.knots = c(0, 15), degree = 2) / 15)
    + I(-3 * bs(year, knots = 2.0534443,
        Boundary.knots = c(0, 15), degree = 2) / 15),

    # one index per column of the two terms above (three columns each)
    # NOTE(review): presumably positions of the mixed-model fixed effects
    # that each column is multiplied with -- verify against the JM docs
    indFixed = c(3,4,5,2,3,4),

    random = ~ -1
    + I(3 * bs(year, knots = 2.0534443,
        Boundary.knots = c(0, 15), degree = 2) / 15)
    + I(-3 * bs(year, knots = 2.0534443,
        Boundary.knots = c(0, 15), degree = 2) / 15),

    # same index pattern for the random effects
    indRandom = c(3,4,5,2,3,4))

# Six joint models that differ only in how the longitudinal outcome enters
# the survival submodel. Models (II)-(VI) are obtained with update(), which
# refits an existing model with the listed arguments changed.

# model (I)
# no `parameterization` argument, so jointModel()'s default is used
jointFitBsp.pbc <- jointModel(lmeFitBsp.pbc, coxFit.pbc,
    timeVar = "year", method = "piecewise-PH-aGH")

# model (II)
# model (I) refitted with the "slope" parameterization, using dform
jointFitBsp2.pbc <- update(jointFitBsp.pbc,
    parameterization = "slope", derivForm = dform)

# model (III)
# model (I) refitted with the "both" parameterization, using dform
jointFitBsp3.pbc <- update(jointFitBsp.pbc,
    parameterization = "both", derivForm = dform)

# model (IV)
# model (I) plus an interFact formula in `Pro` for the `value` term
jointFitBsp4.pbc <- update(jointFitBsp.pbc,
    interFact = list(value = ~ Pro, data = pbc2.id))

# model (V)
# model (II) plus an interFact formula in `Pro` for the `slope` term
jointFitBsp5.pbc <- update(jointFitBsp2.pbc,
    interFact = list(slope = ~ Pro, data = pbc2.id))

# model (VI)
# model (III) plus interFact formulas in `Pro` for both terms
jointFitBsp6.pbc <- update(jointFitBsp3.pbc,
    interFact = list(value = ~ Pro, slope = ~ Pro, data = pbc2.id))


# Data of Patient 51
ND2 <- pbc2[pbc2$id %in% 51, ]

# Dynamic predictions for the longitudinal outcome under the six joint
# models. In iteration i only the first i measurements of the patient are
# used and the outcome is predicted one year after the latest of them
# (FtTimes = time). The seed is set once per iteration, so the six predict()
# calls must stay in this order to reproduce the same random draws.
# seq_len() is used instead of 1:nrow(ND2) so that an empty ND2 yields zero
# iterations rather than the sequence c(1, 0).
longPreds1 <- longPreds2 <- longPreds3 <- vector("list", nrow(ND2))
longPreds4 <- longPreds5 <- longPreds6 <- vector("list", nrow(ND2))
for (i in seq_len(nrow(ND2))) {
    set.seed(123)
    nd2 <- ND2[seq_len(i), ]
    time <- max(nd2$year) + 1
    longPreds1[[i]] <- predict(jointFitBsp.pbc, nd2, type = "Subject",
        FtTimes = time, interval = "confidence", M = 500)
    longPreds2[[i]] <- predict(jointFitBsp2.pbc, nd2, type = "Subject",
        FtTimes = time, interval = "confidence", M = 500)
    longPreds3[[i]] <- predict(jointFitBsp3.pbc, nd2, type = "Subject",
        FtTimes = time, interval = "confidence", M = 500)
    longPreds4[[i]] <- predict(jointFitBsp4.pbc, nd2, type = "Subject",
        FtTimes = time, interval = "confidence", M = 500)
    longPreds5[[i]] <- predict(jointFitBsp5.pbc, nd2, type = "Subject",
        FtTimes = time, interval = "confidence", M = 500)
    longPreds6[[i]] <- predict(jointFitBsp6.pbc, nd2, type = "Subject",
        FtTimes = time, interval = "confidence", M = 500)
}

# Dynamic predictions for the survival outcome under the six joint models.
# In iteration i only the first i measurements of Patient 51 are used and
# survival is predicted one year after the latest of them (survTimes = time).
# The seed is set once per iteration, so the six survfitJM() calls must stay
# in this order to reproduce the same random draws.
# seq_len() is used instead of 1:nrow(ND2) so that an empty ND2 yields zero
# iterations rather than the sequence c(1, 0).
survPreds1 <- survPreds2 <- survPreds3 <- vector("list", nrow(ND2))
survPreds4 <- survPreds5 <- survPreds6 <- vector("list", nrow(ND2))
for (i in seq_len(nrow(ND2))) {
    set.seed(123)
    nd2 <- ND2[seq_len(i), ]
    time <- max(nd2$year) + 1
    survPreds1[[i]] <- survfitJM(jointFitBsp.pbc, nd2,
        survTimes = time, M = 500)
    survPreds2[[i]] <- survfitJM(jointFitBsp2.pbc, nd2,
        survTimes = time, M = 500)
    survPreds3[[i]] <- survfitJM(jointFitBsp3.pbc, nd2,
        survTimes = time, M = 500)
    survPreds4[[i]] <- survfitJM(jointFitBsp4.pbc, nd2,
        survTimes = time, M = 500)
    survPreds5[[i]] <- survfitJM(jointFitBsp5.pbc, nd2,
        survTimes = time, M = 500)
    survPreds6[[i]] <- survfitJM(jointFitBsp6.pbc, nd2,
        survTimes = time, M = 500)
}


#################
# Section 7.4.6 #
#################


# indicator of baseline measurement
# t0 is 1 for rows with time == 0 and 0 otherwise
prothro$t0 <- as.numeric(prothro$time == 0)
# Linear mixed model for `pro`: treatment main effect plus its interactions
# with time and with the baseline indicator; random intercept and random
# slope for time within id.
lmeFitBsp.pro <- lme(pro ~ treat * (time + t0), random = ~ time | id,
    data = prothro)

# Cox model for death with treatment as the only covariate, fitted on
# `prothros`; x = TRUE stores the design matrix in the fit.
coxFit.pro <- coxph(Surv(Time, death) ~ treat, data = prothros,
    x = TRUE)

# Joint model built from the two fits above, with "time" as the time
# variable of the longitudinal data.
jointFitBsp.pro <- jointModel(lmeFitBsp.pro, coxFit.pro,
    timeVar = "time", method = "piecewise-PH-aGH")

summary(jointFitBsp.pro)

# Data on which to base the ROC calculations: a single placebo subject
# (id = 1) with measurement times 0, 0.25, 1, 3 and 4, plus the
# baseline-measurement indicator t0 that the mixed model uses.
plcbData <- transform(
    data.frame(
        id = 1,
        treat = factor("placebo", levels = levels(prothro$treat)),
        time = c(0, 0.25, 1, 3, 4)),
    t0 = as.numeric(time == 0)
)

plcbData

# ROC estimation for Dt = (1, 2, 4) based on the standard prediction rule;
# the seed is fixed for reproducibility.
set.seed(123)
ROCplcb <- rocJM(
    jointFitBsp.pro,
    data = plcbData,
    dt = c(1, 2, 4),
    M = 1000,
    burn.in = 500)

# estimated AUCs and optimal thresholds
ROCplcb

# ROC curves
plot(ROCplcb, legend = TRUE)


# Time-dependent ROCs under the simple rule: element i uses only the first i
# rows of plcbData. The list is sized from plcbData instead of a hard-coded
# 5, so the loop can never index past the available rows.
ROCs <- vector("list", nrow(plcbData))
for (i in seq_along(ROCs)) {
    set.seed(123) # we set the seed for reproducibility
    ROCs[[i]] <- rocJM(jointFitBsp.pro, dt = c(1, 2, 4),
        data = plcbData[seq_len(i), ], M = 1000, burn.in = 500)
}

# Plot of the time-dependent ROCs based on 2 to 5 measurements in a 2 x 2
# grid. The previous graphical parameters are saved and restored afterwards
# so the layout does not leak into later base-graphics plots.
op <- par(mfrow = c(2, 2), oma = c(0, 0, 2, 0))
for (i in 2:5) {
    plot(ROCs[[i]], legend = TRUE)
}
mtext("Prediction rule: Simple", side = 3, line = -1, outer = TRUE)
par(op)


# ROCs based on the composite rule, using relative differences with
# rel.diff = c(1, 0.8), i.e. assuming a 20% decrease in prothrombin.
set.seed(123)
ROCplcb.Rel <- rocJM(
    jointFitBsp.pro,
    data = plcbData,
    dt = c(1, 2, 4),
    diffType = "relative",
    rel.diff = c(1, 0.8),
    M = 1000,
    burn.in = 500)

# estimated AUCs and optimal thresholds under the composite rule
ROCplcb.Rel

# ROC curves under the composite rule
plot(ROCplcb.Rel, legend = TRUE)

# Time-dependent ROCs under the composite rule: element i uses only the
# first i rows of plcbData. The loop iterates over ROCs.r itself (it used to
# borrow its index range from the separate list ROCs), and the list is sized
# from plcbData instead of a hard-coded 5.
ROCs.r <- vector("list", nrow(plcbData))
for (i in seq_along(ROCs.r)) {
    set.seed(123)
    ROCs.r[[i]] <- rocJM(jointFitBsp.pro, dt = c(1, 2, 4),
        data = plcbData[seq_len(i), ], diffType = "relative",
        rel.diff = c(1, 0.8), M = 1000, burn.in = 500)
}

# Plot of the time-dependent ROCs under the composite rule in a 2 x 2 grid.
# The individual panels no longer get the title "Prediction rule: Simple",
# which contradicted the composite rule shown here; they are drawn like the
# panels of the simple-rule figure, and the outer title names the rule.
# The previous graphical parameters are restored afterwards.
op <- par(mfrow = c(2, 2), oma = c(0, 0, 2, 0))
for (i in 2:5) {
    plot(ROCs.r[[i]], legend = TRUE)
}
mtext("Prediction rule: Composite", side = 3, line = -1, outer = TRUE)
par(op)


# sensitivity to the parameterization
# formulas to calculate the time-dependent slope term
# The mixed model is pro ~ treat * (time + t0) with random = ~ time | id.
# fixed = ~ treat with indFixed = c(3, 5) and random = ~ 1 with
# indRandom = 2 select the terms that involve `time`.
# NOTE(review): positions 3 and 5 are presumably the `time` and `treat:time`
# coefficients, and position 2 the random slope -- confirm against
# fixef(lmeFitBsp.pro) if the model formula changes.
dform2 <- list(fixed = ~ treat, indFixed = c(3, 5),
    random = ~ 1, indRandom = 2)

# the joint model refitted with the "slope" parameterization
jointFitBsp2.pro <- update(jointFitBsp.pro,
    parameterization = "slope", derivForm = dform2)

# the joint model refitted with the "both" parameterization
jointFitBsp3.pro <- update(jointFitBsp.pro,
    parameterization = "both", derivForm = dform2)


# ROCs under the composite rule for Model (II), the "slope" parameterization.
# directionSmaller = TRUE is passed explicitly for this model only.
set.seed(123)
ROCplcb.Rel2 <- rocJM(
    jointFitBsp2.pro,
    data = plcbData,
    dt = c(1, 2, 4),
    directionSmaller = TRUE,
    diffType = "relative",
    rel.diff = c(1, 0.8),
    M = 1000,
    burn.in = 500)

ROCplcb.Rel2


# ROCs under the composite rule for Model (III), the "both" parameterization
set.seed(123)
ROCplcb.Rel3 <- rocJM(
    jointFitBsp3.pro,
    data = plcbData,
    dt = c(1, 2, 4),
    diffType = "relative",
    rel.diff = c(1, 0.8),
    M = 1000,
    burn.in = 500)

ROCplcb.Rel3


# Data to calculate the discrimination index: a single placebo subject
# measured at 0, 0.25, 0.5 and then yearly from 1 to 10, plus the
# baseline-measurement indicator t0.
plcbData2 <- transform(
    data.frame(
        id = 1,
        treat = factor("placebo", levels = levels(prothro$treat)),
        time = c(0, 0.25, 0.5, 1:10)),
    t0 = as.numeric(time == 0)
)

# Note: The following piece of code takes some time to execute.
# For every number of available measurements i, estimate the ROC at dt = 4
# under the composite rule for Models (I), (II) and (III). The seed is set
# once per iteration, so the three calls must stay in this order.
ROCs.MI <- ROCs.MII <- ROCs.MIII <- vector("list", nrow(plcbData2))
for (i in seq_along(ROCs.MI)) {
    set.seed(123)
    ROCs.MI[[i]] <- rocJM(
        jointFitBsp.pro, data = plcbData2[seq_len(i), ], dt = 4,
        diffType = "relative", rel.diff = c(1, 0.8),
        M = 1000, burn.in = 500)

    # Model (II) is the only one called with directionSmaller = TRUE
    ROCs.MII[[i]] <- rocJM(
        jointFitBsp2.pro, data = plcbData2[seq_len(i), ], dt = 4,
        directionSmaller = TRUE,
        diffType = "relative", rel.diff = c(1, 0.8),
        M = 1000, burn.in = 500)

    ROCs.MIII[[i]] <- rocJM(
        jointFitBsp3.pro, data = plcbData2[seq_len(i), ], dt = 4,
        diffType = "relative", rel.diff = c(1, 0.8),
        M = 1000, burn.in = 500)
}


# Extract the "AUCs" component of every ROC object, per model
AUCs.MI <- sapply(ROCs.MI, function (roc) roc[["AUCs"]])
AUCs.MII <- sapply(ROCs.MII, function (roc) roc[["AUCs"]])
AUCs.MIII <- sapply(ROCs.MIII, function (roc) roc[["AUCs"]])

# Kaplan-Meier estimate for the placebo group, evaluated at the same time
# points as the measurement times in plcbData2
sf <- survfit(
    Surv(Time, death) ~ treat,
    data = prothros,
    subset = treat == "placebo")
Surv.Plcb <- summary(sf, times = c(0, 0.25, 0.5, 1:10))[["surv"]]


# Integrals over the time grid by the trapezoidal rule: a and b are the left
# and right end points of the consecutive intervals.
times <- c(0, 0.25, 0.5, 1:10)
a <- head(times, -1)
b <- tail(times, -1)
rbind(a, b)

# Trapezoidal integral of the values y given on `times`: interval width
# times the mean of the two end-point values, summed over all intervals.
trapezoid <- function (y) {
    left <- head(y, -1)
    right <- tail(y, -1)
    sum((b - a) * (left + right) / 2)
}

# denominator: integral of the Kaplan-Meier survival curve
Denom <- trapezoid(Surv.Plcb)

# numerators: integrals of the AUCs weighted by the survival curve
Numer.MI <- trapezoid(AUCs.MI * Surv.Plcb)
Numer.MII <- trapezoid(AUCs.MII * Surv.Plcb)
Numer.MIII <- trapezoid(AUCs.MIII * Surv.Plcb)

# the dynamic discrimination index
Numer.MI / Denom # Cdyn Model (I)
Numer.MII / Denom # Cdyn Model (II)
Numer.MIII / Denom # Cdyn Model (III)