### Adaptive Rejection Sampling
# by Will Townes
#
# Samples from a univariate log-concave density f without needing
# derivatives. The upper envelope of h(x) = log f(x) is piecewise linear and
# is built from the chords joining consecutive grid points xpts: inside the
# interval [xpts[i], xpts[i+1]] it is the extension of chord i-1 up to
# xstar[i] and the extension of chord i+1 after it; left of the grid chord 1
# is extended down to lo, right of the grid the last chord is extended up to
# hi. Concavity makes every such extension lie above h. Each rejected draw is
# added to the grid, so the envelope tightens as sampling proceeds.
#
# Sub-interval numbering shared by calc_wts() and ars(), for nIvl intervals
# (nIvl + 1 grid points):
#   j = 1             [lo, xpts[1]]           chord 1
#   j = 2i            [xpts[i], xstar[i]]     chord i - 1
#   j = 2i + 1        [xstar[i], xpts[i+1]]   chord i + 1
#   j = 2 * nIvl + 2  [xpts[nIvl + 1], hi]    chord nIvl

rexp_trunc <- function(n, slope = -1, lo = 0, hi = Inf) {
  # Draw n samples from the distribution on [lo, hi] whose density is
  # proportional to exp(slope * x) (a truncated exponential).
  # The default (slope = -1, lo = 0, hi = Inf) is the standard exponential.
  # lo may be -Inf only when slope > 0; hi may be +Inf only when slope < 0.
  # slope == 0 gives the uniform distribution and needs finite bounds.
  stopifnot(lo <= hi)
  u <- runif(n)
  if (slope == 0) {
    stopifnot(is.finite(lo), is.finite(hi))
    return(lo + u * (hi - lo))
  }
  len <- hi - lo
  if (slope > 0) {
    # Invert the CDF anchored at hi so expm1() only sees non-positive
    # arguments; anchoring at lo overflows once slope * len is large.
    stopifnot(is.finite(hi))
    hi + log1p((1 - u) * expm1(-slope * len)) / slope
  } else {
    stopifnot(is.finite(lo))
    lo + log1p(u * expm1(slope * len)) / slope
  }
}

ars_wt_formula <- function(a, b, c, d, numeric_zero = 1e-8) {
  # Integral of exp(a * x + b) over [c, d], i.e.
  # (1/a) exp(b) (exp(a d) - exp(a c)): the un-normalized probability mass of
  # one envelope sub-interval. a is the slope and b the intercept of the line.
  # Requires c <= d. c may be -Inf (finite only if a > 0) and d may be +Inf
  # (finite only if a < 0); an infinite mass is returned as Inf.
  # Slopes with abs(a) < numeric_zero on a finite interval count as flat.
  dmc <- d - c
  stopifnot(dmc >= 0)
  if (dmc == 0) {
    return(0)
  }
  if (abs(a) < numeric_zero && is.finite(dmc)) {
    # Midpoint rule; relative error is about (a * dmc)^2 / 24.
    return(exp(b + a * (c + d) / 2) * dmc)
  }
  if (a > 0) {
    # Factor out the larger endpoint value so exp() cannot overflow needlessly.
    exp(a * d + b) * -expm1(-a * dmc) / a
  } else if (a < 0) {
    exp(a * c + b) * expm1(a * dmc) / a
  } else {
    Inf
  }
}

calc_wts_inner <- function(i, xpts, xstar, a, b) {
  # Envelope mass of the two parts of interval i: [xpts[i], xstar[i]] under
  # chord i - 1 and [xstar[i], xpts[i + 1]] under chord i + 1.
  # The first interval has no chord 0 and the last has no chord nIvl + 1;
  # get_xstar() makes those parts zero-width, so they get weight 0.
  left <- if (i > 1) {
    ars_wt_formula(a[i - 1], b[i - 1], xpts[i], xstar[i])
  } else {
    0
  }
  right <- if (i < length(a)) {
    ars_wt_formula(a[i + 1], b[i + 1], xstar[i], xpts[i + 1])
  } else {
    0
  }
  c(left, right)
}

calc_wts <- function(xpts, xstar, a, b, lo, hi, idx = seq_along(a)) {
  # Un-normalized envelope mass of every sub-interval, ordered by
  # sub-interval index j = 1, ..., 2 * nIvl + 2: left tail, then the left and
  # right parts of each interval, then the right tail.
  nIvl <- length(idx)
  stopifnot(nIvl > 1)
  inner <- vapply(idx, calc_wts_inner, numeric(2), xpts, xstar, a, b)
  c(
    ars_wt_formula(a[1], b[1], lo, xpts[1]),
    as.vector(inner),
    ars_wt_formula(a[nIvl], b[nIvl], xpts[nIvl + 1], hi)
  )
}

subinterval_to_interval <- function(j) {
  # Interval index of sub-interval j (0 = left tail, nIvl + 1 = right tail).
  floor(j / 2)
}

get_xstar <- function(xpts, a, b, idx = seq_along(a)) {
  # Breakpoints of the envelope: within interval i = [xpts[i], xpts[i + 1]],
  # xstar[i] is where the extensions of chords i - 1 and i + 1 cross.
  # a, b are the chord slopes and intercepts. Interval 1 has no chord 0, so
  # xstar[1] = xpts[1]; interval nIvl has no chord nIvl + 1, so
  # xstar[nIvl] = xpts[nIvl + 1].
  nIvl <- length(idx)
  left <- xpts[idx]
  right <- xpts[idx + 1]
  xstar <- left
  xstar[nIvl] <- right[nIvl]
  if (nIvl > 2) {
    inner <- 2:(nIvl - 1)
    # Solve a[i-1] x + b[i-1] == a[i+1] x + b[i+1].
    xstar[inner] <- -diff(b, lag = 2) / diff(a, lag = 2)
  }
  # Nearly collinear chords (rounding error) can put the crossing outside its
  # interval, and exactly collinear ones give 0/0 = NaN; either would yield
  # negative or NaN weights. Any point works for collinear chords, so use
  # the left endpoint, then clamp every breakpoint into its interval.
  is_nan <- is.nan(xstar)
  xstar[is_nan] <- left[is_nan]
  pmin(pmax(xstar, left), right)
}

ars_chords <- function(xpts, ypts) {
  # Slopes a and intercepts b of the chords joining consecutive grid points
  # (xpts[i], ypts[i]) and (xpts[i + 1], ypts[i + 1]).
  n <- length(xpts)
  a <- diff(ypts) / diff(xpts)
  list(a = a, b = ypts[-n] - a * xpts[-n])
}

ars <- function(func, nSample = 1, xpts, lo = 0, hi = 1, logscale = TRUE,
                verbose = TRUE) {
  # Draw nSample values from the density f on [lo, hi]; f must be
  # log-concave there.
  # func:    vectorized; returns log(f(x)) if logscale is TRUE, else f(x).
  # xpts:    initial envelope grid, at least 3 points in [lo, hi] where
  #          f(x) > 0. If lo is -Inf, log f must increase between the two
  #          smallest points; if hi is +Inf, it must decrease between the two
  #          largest, so that the envelope has finite mass.
  # lo, hi:  bounds of the domain of f; may be -Inf / +Inf.
  # verbose: currently unused.
  # Returns a numeric vector of length nSample.
  log_f <- if (logscale) func else function(x) log(func(x))
  xpts <- sort(xpts)
  stopifnot(length(xpts) >= 3, lo <= xpts[1], xpts[length(xpts)] <= hi)
  ypts <- log_f(xpts)
  stopifnot(all(is.finite(ypts)))
  # Work with log f minus its grid maximum. Accept/reject only depends on
  # differences and sample.int() normalizes the weights, so this is exact;
  # it stops exp() from underflowing to all-zero weights (or overflowing)
  # when func is an unnormalized log density far from 0.
  shift <- max(ypts)
  h <- function(x) log_f(x) - shift
  ypts <- ypts - shift

  chords <- ars_chords(xpts, ypts)
  a <- chords$a
  b <- chords$b
  nIvl <- length(a)
  if (lo == -Inf && !isTRUE(a[1] > 0)) {
    stop("lo is -Inf but log f does not increase between the two smallest ",
         "xpts; add a grid point further left", call. = FALSE)
  }
  if (hi == Inf && !isTRUE(a[nIvl] < 0)) {
    stop("hi is +Inf but log f does not decrease between the two largest ",
         "xpts; add a grid point further right", call. = FALSE)
  }
  xstar <- get_xstar(xpts, a, b)
  w <- calc_wts(xpts, xstar, a, b, lo, hi)

  res <- rep(NA_real_, nSample)
  nRes <- 0
  while (nRes < nSample) {
    j <- sample.int(length(w), 1L, prob = w)
    i <- subinterval_to_interval(j)
    if (j %% 2 == 1) {
      # [xstar[i], xpts[i + 1]] (left tail when i == 0), under chord i + 1
      k <- i + 1
      c_lo <- if (i > 0) xstar[i] else lo
      c_hi <- xpts[i + 1]
    } else {
      # [xpts[i], xstar[i]] (right tail when i == nIvl + 1), under chord i - 1
      k <- i - 1
      c_lo <- xpts[i]
      c_hi <- if (i <= nIvl) xstar[i] else hi
    }
    x <- rexp_trunc(1, slope = a[k], lo = c_lo, hi = c_hi)
    hx <- h(x)
    # Accept with probability exp(hx - envelope(x)); rexp(1) is distributed
    # as -log(runif(1)).
    if (isTRUE(a[k] * x + b[k] <= hx + rexp(1))) {
      nRes <- nRes + 1
      res[nRes] <- x
    } else if (is.finite(hx) && !(x %in% xpts)) {
      # Refine the envelope. x lies between xpts[i] and xpts[i + 1], so
      # inserting it after position i keeps xpts sorted. Duplicate points
      # would give a zero-width chord (0/0 slope) and are skipped.
      xpts <- append(xpts, x, after = i)
      ypts <- append(ypts, hx, after = i)
      chords <- ars_chords(xpts, ypts)
      a <- chords$a
      b <- chords$b
      nIvl <- length(a)
      xstar <- get_xstar(xpts, a, b)
      w <- calc_wts(xpts, xstar, a, b, lo, hi)
    }
  }
  res
}