library(simcausal)

#### Custom Distribution Functions ####
# Sample from a normal distribution and clamp the lower tail.
#
# Values falling below the lower bound are set equal to the bound (they are
# not redrawn), so the result has a point mass at the bound.
#
# Arguments:
#   n      - number of values to draw.
#   mean   - mean passed to rnorm().
#   sd     - standard deviation passed to rnorm().
#   minval - lower bound; only its first element is used, so a vector
#            argument is reduced to a single bound.
#
# Returns: numeric vector of draws with every value below the bound replaced
#   by the bound.
rnorm_trunc <- function(n, mean, sd, minval = 0) {
  draws <- rnorm(n = n, mean = mean, sd = sd)
  lower_bound <- minval[1]
  replace(draws, draws < lower_bound, lower_bound)
}


# Draw one log-normal value per group and give every member of a group the
# same value (a shared group-level effect).
#
# Arguments:
#   n      - number of observations; must equal length(groups).
#   mean   - log-scale mean, forwarded to rlnorm() as `meanlog`.
#   sd     - log-scale standard deviation, forwarded to rlnorm() as `sdlog`.
#   groups - vector of group labels, one per observation. Labels may be any
#            sortable values; they need not be the consecutive integers 1..m.
#
# Returns: numeric vector of length n in which observations sharing a group
#   label share a single draw. Observations whose label is NA get NA.
rlnorm_group <- function(n, mean, sd, groups) {
  stopifnot(length(groups) == n)

  # Sorted distinct labels define the draw order, so labels that happen to be
  # exactly 1..m are paired with draws 1..m respectively.
  group_levels <- sort(unique(groups))
  group_draws <- rlnorm(length(group_levels), meanlog = mean, sdlog = sd)

  # Look up each observation's group position instead of indexing the draws
  # by the raw label; this is a single vectorised pass over `groups`.
  group_draws[match(groups, group_levels)]
}

# Sample exponential values, rescale them by 1/10, and cap the result at 10.
#
# Arguments:
#   n    - number of values to draw.
#   mean - mean of the exponential distribution before rescaling; the draw
#          uses rate 1 / mean.
#
# Returns: numeric vector of rescaled draws, with every value above 10
#   replaced by 10.
rexp_age <- function(n, mean) {
  scaled <- rexp(n, 1 / mean) / 10
  replace(scaled, scaled > 10, 10)
}

# Draw one normal value per group and give every member of a group the same
# value (a shared group-level effect, e.g. a random intercept).
#
# Arguments:
#   n      - number of observations; must equal length(groups).
#   mean   - mean forwarded to rnorm().
#   sd     - standard deviation forwarded to rnorm().
#   groups - vector of group labels, one per observation. Labels may be any
#            sortable values; they need not be the consecutive integers 1..m.
#
# Returns: numeric vector of length n in which observations sharing a group
#   label share a single draw. Observations whose label is NA get NA.
rnorm_group <- function(n, mean, sd, groups) {
  stopifnot(length(groups) == n)

  # Sorted distinct labels define the draw order, so labels that happen to be
  # exactly 1..m are paired with draws 1..m respectively.
  group_levels <- sort(unique(groups))
  group_draws <- rnorm(length(group_levels), mean = mean, sd = sd)

  # Look up each observation's group position instead of indexing the draws
  # by the raw label; this is a single vectorised pass over `groups`.
  group_draws[match(groups, group_levels)]
}

#### Creating DAG ####
# Build the data-generating DAG one node at a time. Nodes are added in
# dependency order: a node's arguments may reference only earlier nodes.
D <- DAG.empty()

# Group membership: integer category with four equally likely levels.
D <- D + node("group",
              distr = "rcategor.int",
              probs = rep(1/4, 4))

# Group-level value drawn by rlnorm_group(), keyed on `group`.
D <- D + node("grpdist",
              distr  = "rlnorm_group",
              mean   = 0,
              sd     = 0.75,
              groups = group)

# X2: normal draw centred on the group-level value, floored at 0.
D <- D + node("X2",
              distr  = "rnorm_trunc",
              mean   = grpdist,
              sd     = 0.05,
              minval = 0)

# X1: drawn by rexp_age() with mean argument 20.
D <- D + node("X1",
              distr = "rexp_age",
              mean  = 20)

# b: group-level normal value drawn by rnorm_group(), keyed on `group`.
D <- D + node("b",
              distr  = "rnorm_group",
              mean   = 0,
              sd     = sqrt(1.0859),
              groups = group)

# B: Bernoulli with a logistic probability in X1, X2 and the group term b.
D <- D + node("B",
              distr = "rbern",
              prob  = plogis(0.2727 - 0.0387 * X1 + 0.2179 * X2 + b))

# A: always 0 when B is 0; otherwise Bernoulli with probability 2/3.
D <- D + node("A",
              distr = "rbern",
              prob  = ifelse(B == 0, 0, 2/3))

# Pass the assembled DAG through set.DAG() to obtain the object used for
# simulation.
D1 <- set.DAG(D)

#### Simulate Data ####

# Simulate 10,000 observations from the DAG and show the first six rows.
dt <- simobs(D1, n = 1e4)
head(dt, n = 6L)

