Skip to content

Instantly share code, notes, and snippets.

@nicolashery
Last active December 13, 2023 21:43
Show Gist options
  • Select an option
  • Save nicolashery/4dcf7003564c576d0d2f4872447c7b02 to your computer and use it in GitHub Desktop.

Nesting APIs and ReaderT environments in Haskell's Servant

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main (main) where
import Relude hiding (traceId)
import Control.Exception (throwIO, try)
import Control.Monad.Logger (
Loc,
LogLevel,
LogSource,
LogStr,
MonadLogger (..),
ToLogStr (toLogStr),
)
import Control.Monad.Logger.Aeson (
Message ((:#)),
logDebug,
logInfo,
runLoggingT,
(.=),
)
import Control.Monad.Logger.Aeson qualified as Logger (defaultOutput)
import Data.Pool (Pool, defaultPoolConfig, newPool, withResource)
import Network.HTTP.Client (
Manager,
defaultManagerSettings,
managerConnCount,
newManager,
)
import Network.Wai.Handler.Warp qualified as Warp
import Servant (
NamedRoutes,
ServerError (..),
err401,
err404,
err500,
hoistServer,
serve,
)
import Servant qualified (Handler (..))
import Servant.API (
Capture,
GenericMode ((:-)),
Get,
GetNoContent,
Header,
NoContent (..),
PlainText,
Post,
ReqBody,
(:>),
)
import Servant.Server qualified as Servant (layout)
import Servant.Server.Internal (AsServerT)
-- API
-- ----------------------------------------------------------------------------
-- Header values, kept as unparsed 'Text'.
type AuthorizationHeader = Text
type TraceParentHeader = Text

-- Identifiers captured from URL path segments.
type OrganizationId = Text
type ProjectId = Text
type TicketId = Text

-- Request and response bodies. Every route exchanges them as 'PlainText',
-- so plain 'Text' is sufficient for this example.
type ListOrganizationsResponse = Text
type LayoutResponse = Text
type CreateProjectRequest = Text
type CreateProjectResponse = Text
type GetProjectResponse = Text
type CreateTicketRequest = Text
type CreateTicketResponse = Text
type GetTicketResponse = Text
-- | The whole HTTP API. Every route lives under @/v1@ and may carry an
-- optional @traceparent@ header; handlers receive it as a 'Maybe' value.
type Api =
  "v1"
    :> Header "traceparent" TraceParentHeader
    :> NamedRoutes RootApi
-- | Routes mounted directly under @/v1@.
data RootApi mode = RootApi
  { health :: mode :- "health" :> GetNoContent
  -- ^ @GET /v1/health@, answered with 'NoContent'.
  , layout :: mode :- "layout" :> Get '[PlainText] LayoutResponse
  -- ^ @GET /v1/layout@, a plain-text description of the routing layout.
  , authenticatedApi
      :: mode
        :- Header "Authorization" AuthorizationHeader
        :> NamedRoutes AuthenticatedApi
  -- ^ Sub-API whose server receives the optional @Authorization@ header.
  }
  deriving stock (Generic)
-- | Routes served once the caller has been authenticated.
data AuthenticatedApi mode = AuthenticatedApi
  { listOrganizations
      :: mode :- "organizations" :> Get '[PlainText] ListOrganizationsResponse
  -- ^ @GET /v1/organizations@.
  , projectApi
      :: mode
        :- "organizations"
        :> Capture "organizationId" OrganizationId
        :> "projects"
        :> NamedRoutes ProjectApi
  -- ^ Project routes under @/v1/organizations/:organizationId/projects@.
  }
  deriving stock (Generic)
-- | Routes scoped to the projects of a single organization.
data ProjectApi mode = ProjectApi
  { createProject
      :: mode
        :- ReqBody '[PlainText] CreateProjectRequest
        :> Post '[PlainText] CreateProjectResponse
  -- ^ @POST .../projects@ with a plain-text request body.
  , getProject
      :: mode
        :- Capture "projectId" ProjectId
        :> Get '[PlainText] GetProjectResponse
  -- ^ @GET .../projects/:projectId@.
  , ticketApi
      :: mode
        :- Capture "projectId" ProjectId
        :> "tickets"
        :> NamedRoutes TicketApi
  -- ^ Ticket routes under @.../projects/:projectId/tickets@.
  }
  deriving stock (Generic)
-- | Routes scoped to the tickets of a single project.
data TicketApi mode = TicketApi
  { createTicket
      :: mode
        :- ReqBody '[PlainText] CreateTicketRequest
        :> Post '[PlainText] CreateTicketResponse
  -- ^ @POST .../tickets@ with a plain-text request body.
  , getTicket
      :: mode
        :- Capture "ticketId" TicketId
        :> Get '[PlainText] GetTicketResponse
  -- ^ @GET .../tickets/:ticketId@.
  }
  deriving stock (Generic)
-- Logging
-- ----------------------------------------------------------------------------
-- | A raw logging callback, in the shape that 'runLoggingT' consumes.
type LogFunc = Loc -> LogSource -> LogLevel -> LogStr -> IO ()

-- | Environments from which a 'LogFunc' can be extracted.
class HasLogFunc env where
  getLogFunc :: env -> LogFunc
-- Fake database
-- ----------------------------------------------------------------------------
-- | Stand-in for a real database connection handle.
data Connection
  = Connection
-- | Build a pool of fake 'Connection's.
--
-- The database URL is ignored. @poolSize@ is passed as the size argument of
-- 'defaultPoolConfig', and @10@ is passed as its TTL argument. Creating and
-- destroying a connection does no work.
createDbPool :: Text -> Int -> IO (Pool Connection)
createDbPool _databaseUrl poolSize =
  newPool (defaultPoolConfig openConn closeConn idleSeconds poolSize)
  where
    openConn = pure Connection
    closeConn _conn = pure ()
    idleSeconds = 10
-- | What database code needs at run time.
data DatabaseEnv = DatabaseEnv
  { dbLogger :: LogFunc
  -- ^ Used by 'query' to log each statement.
  , connectionPool :: Pool Connection
  -- ^ Where 'withConnection' borrows connections from.
  }

-- | Environments that embed a 'DatabaseEnv'.
class HasDatabase env where
  getDatabase :: env -> DatabaseEnv

-- | Monad for database code: a reader over 'DatabaseEnv' on top of 'IO'.
newtype Database a = Database {unDatabase :: ReaderT DatabaseEnv IO a}
  deriving newtype
    (Functor, Applicative, Monad, MonadIO, MonadReader DatabaseEnv)
-- | Execute a 'Database' action in 'IO' against the given environment.
runDatabaseIO :: DatabaseEnv -> Database a -> IO a
runDatabaseIO env (Database reader) = runReaderT reader env
-- | A reader monad whose environment can supply a 'DatabaseEnv'.
type MonadDatabase env m = (MonadReader env m, HasDatabase env)

-- | Run a 'Database' action from any monad whose environment provides a
-- 'DatabaseEnv', lifting the underlying 'IO'.
runDatabase
  :: (MonadDatabase env m, MonadIO m)
  => Database a
  -> m a
runDatabase action =
  asks getDatabase >>= \dbEnv -> liftIO (runDatabaseIO dbEnv action)
-- | Fake query.
--
-- Logs the SQL text and the 'show'n parameters at debug level through the
-- environment's 'dbLogger', borrows a pooled connection, and always returns
-- an empty result list.
query :: (Show p) => Text -> p -> Database [r]
query q parameters = do
  dbLog <- asks dbLogger
  runLoggingT (logDebug logMessage) dbLog
  withConnection (\_conn -> pure [])
  where
    logMessage =
      "Database.query"
        :# [ "query" .= q
           , "parameters" .= (show parameters :: Text)
           ]
-- | Run @action@ with a connection taken from the environment's pool via
-- 'withResource'.
withConnection :: (Connection -> IO a) -> Database a
withConnection action =
  asks connectionPool >>= \pool -> liftIO (withResource pool action)
-- Fake authentication
-- ----------------------------------------------------------------------------
-- | Identifier of an authenticated user.
type UserId = Text

-- | Fake header parser.
--
-- A missing header is an error. Any present header (its value is never
-- inspected) maps to the same hard-coded user id.
parseAuthHeader :: Maybe AuthorizationHeader -> Either Text UserId
parseAuthHeader =
  maybe
    (Left "Missing 'Authorization' header")
    (const (Right "d42ed530-adba-41f0-99af-60bd6c476617"))
-- | Resolve the caller's 'UserId' from the optional @Authorization@ header.
--
-- The auth key is currently unused. When 'parseAuthHeader' rejects the header,
-- 'err401' is thrown with 'throwIO'; the parser's message is not forwarded.
authenticateUser
  :: (MonadIO m)
  => Text
  -> Maybe AuthorizationHeader
  -> m UserId
authenticateUser _authKey maybeAuthHeader =
  either rejectUnauthenticated pure (parseAuthHeader maybeAuthHeader)
  where
    rejectUnauthenticated _reason =
      liftIO . throwIO $
        err401 {errBody = "Missing or invalid 'Authorization' header"}
-- | Authentication state for the current request.
data AuthEnv = AuthEnv
  { userId :: UserId
  -- ^ User returned by 'authenticateUser'.
  }

-- | Environments that embed an 'AuthEnv'.
class HasAuth env where
  getAuth :: env -> AuthEnv

-- | A reader monad whose environment can supply an 'AuthEnv'.
type MonadAuth env m = (MonadReader env m, HasAuth env)

-- | Id of the authenticated user for the current request.
getUserId :: (MonadAuth env m) => m Text
getUserId = asks (userId . getAuth)
-- Fake tracing
-- ----------------------------------------------------------------------------
-- | Placeholder tracer handle.
data Tracer = Tracer

-- | Placeholder span.
data Span = Span

-- | Per-request tracing state.
data TracingEnv = TracingEnv
  { tracer :: Tracer
  , activeSpan :: IORef Span
  -- ^ Mutable slot for the current span; 'childSpan' rewrites it in place.
  }

-- | Environments that embed a 'TracingEnv'.
class HasTracing env where
  getTracing :: env -> TracingEnv

-- | A reader monad whose environment can supply a 'TracingEnv'.
type MonadTracing env m = (MonadReader env m, HasTracing env)
-- | Fake tracer constructor; the name argument is ignored.
createTracer :: (MonadIO m) => Text -> m Tracer
createTracer _tracerName = pure Tracer

-- | Fake root-span constructor; the incoming @traceparent@ value is ignored.
createNewSpan :: (MonadIO m) => Maybe TraceParentHeader -> m Span
createNewSpan _traceParent = pure Span
-- | Fake child-span creation: atomically rewrites the active-span slot.
--
-- The placeholder keeps the stored 'Span' unchanged and ignores the span
-- name. The strict 'atomicModifyIORef'' is used so that repeated updates
-- cannot pile up unevaluated thunks inside the 'IORef' once a real span type
-- is stored there.
childSpan :: (MonadIO m) => IORef Span -> Text -> m ()
childSpan spanRef _childSpanName =
  atomicModifyIORef' spanRef (\currentSpan -> (currentSpan, ()))
-- | Register a child span named @spanName@ on the request's active span, then
-- run @action@. The fake tracing layer does nothing after @action@ finishes.
traced :: (MonadTracing env m, MonadIO m) => Text -> m a -> m a
traced spanName action = do
  spanRef <- asks (activeSpan . getTracing)
  childSpan spanRef spanName
  action
-- Fake organization service client
-- ----------------------------------------------------------------------------
-- | An organization, as returned by the organization service client.
data Organization = Organization
  { organizationId :: OrganizationId
  , name :: Text
  }

-- | Organization service client, modelled as a record of 'IO' functions.
data OrganizationService = OrganizationService
  { fetchUserOrganizations :: UserId -> IO [Organization]
  -- ^ Organizations for a given user (presumably the user's memberships).
  , fetchOrganization :: OrganizationId -> IO Organization
  -- ^ Look up one organization by id.
  }
-- | Build a fake 'OrganizationService'.
--
-- The HTTP manager and base URL are accepted but unused. Every user gets the
-- same two canned organizations, and any id lookup answers with "Org 1".
createOrganizationServiceClient :: Manager -> Text -> OrganizationService
createOrganizationServiceClient _httpManager _serviceBaseUrl =
  OrganizationService
    { fetchUserOrganizations = const (pure cannedOrganizations)
    , fetchOrganization = \orgId -> pure (mkOrganization orgId "Org 1")
    }
  where
    mkOrganization orgId orgName =
      Organization {organizationId = orgId, name = orgName}
    cannedOrganizations =
      [ mkOrganization "90ee1361-ee8b-4b22-be38-14bf46a28cfd" "Org 1"
      , mkOrganization "6e0549c0-15da-4262-9046-4357413c2791" "Org 2"
      ]
-- App (Root)
-- ----------------------------------------------------------------------------
-- | Process-wide dependencies, created once in 'main' and shared by requests.
data Dependencies = Dependencies
  { dbPool :: Pool Connection
  , depsLogger :: LogFunc
  , tracer :: Tracer
  , authKey :: Text
  , organizationService :: OrganizationService
  }

-- | Per-request environment for root handlers; built by 'runApp'.
data AppEnv = AppEnv
  { appLogger :: LogFunc
  , databaseEnv :: DatabaseEnv
  , tracingEnv :: TracingEnv
  }

-- | Monad for handlers at the root of the API.
newtype App a = App {unApp :: ReaderT AppEnv IO a}
  deriving newtype
    (Functor, Applicative, Monad, MonadIO, MonadReader AppEnv)
-- | Environments that embed an 'AppEnv'.
class HasApp env where
  getApp :: env -> AppEnv

instance HasApp AppEnv where
  getApp = identity

-- Logging, database and tracing access are all derived from 'HasApp', so any
-- environment that embeds an 'AppEnv' gets them automatically.
instance (HasApp env) => HasLogFunc env where
  getLogFunc env = appLogger (getApp env)

instance (HasApp env) => HasDatabase env where
  getDatabase env = databaseEnv (getApp env)

instance (HasApp env) => HasTracing env where
  getTracing env = tracingEnv (getApp env)

-- | Log through the environment's 'LogFunc'.
instance MonadLogger App where
  monadLoggerLog loc logSource logLevel msg = do
    logFunc <- asks getLogFunc
    liftIO (logFunc loc logSource logLevel (toLogStr msg))
-- | Run an 'App' action in 'IO' against a fully built 'AppEnv'.
runAppIO :: AppEnv -> App a -> IO a
runAppIO env = flip runReaderT env . unApp
-- | Run an 'App' action as a Servant handler.
--
-- A 'ServerError' thrown (via 'throwIO') inside the action becomes the
-- handler's error result. 'try' only catches 'ServerError', so any other
-- exception propagates unchanged.
runAppServant
  :: AppEnv
  -> App a
  -> Servant.Handler a
runAppServant env action =
  Servant.Handler (ExceptT (try (runAppIO env action)))
-- | Build a per-request 'AppEnv' from the shared 'Dependencies' and run an
-- 'App' action as a Servant handler.
--
-- Each run creates a fresh root span from the optional @traceparent@ header
-- and stores it in a new 'IORef'.
runApp :: Dependencies -> Maybe TraceParentHeader -> App a -> Servant.Handler a
runApp
  Dependencies {dbPool, depsLogger, tracer = sharedTracer}
  maybeTraceParentHeader
  action = do
    rootSpan <- createNewSpan maybeTraceParentHeader
    spanRef <- newIORef rootSpan
    runAppServant (mkAppEnv spanRef) action
    where
      mkAppEnv spanRef =
        AppEnv
          { appLogger = depsLogger
          , databaseEnv =
              DatabaseEnv {dbLogger = depsLogger, connectionPool = dbPool}
          , tracingEnv =
              TracingEnv {tracer = sharedTracer, activeSpan = spanRef}
          }
-- | Servant server for 'Api'. Every 'RootApi' handler is converted to a
-- 'Servant.Handler' by 'runApp', which receives the request's optional
-- @traceparent@ header.
server
  :: Dependencies
  -> Maybe TraceParentHeader
  -> RootApi (AsServerT Servant.Handler)
server deps maybeTraceParentHeader =
  hoistServer (Proxy @(NamedRoutes RootApi)) toHandler (rootServer deps)
  where
    toHandler :: App x -> Servant.Handler x
    toHandler = runApp deps maybeTraceParentHeader
-- | Handlers for 'RootApi'.
--
-- The authenticated sub-API is written in 'AppAuthenticated' and hoisted into
-- 'App' by 'runAppAuthenticated', which checks the @Authorization@ header.
rootServer :: Dependencies -> RootApi (AsServerT App)
rootServer deps =
  RootApi
    { health = healthHandler
    , layout = layoutHandler
    , authenticatedApi = \maybeAuthHeader ->
        hoistServer
          (Proxy @(NamedRoutes AuthenticatedApi))
          (runAppAuthenticated authDeps maybeAuthHeader)
          (authenticatedServer maybeAuthHeader)
    }
  where
    authDeps = getDependenciesAuthenticated deps
-- | Extract the subset of 'Dependencies' that the authenticated layer uses.
getDependenciesAuthenticated :: Dependencies -> DependenciesAuthenticated
getDependenciesAuthenticated
  Dependencies {authKey = key, organizationService = orgClient} =
    DependenciesAuthenticated
      { authKey = key
      , organizationService = orgClient
      }
-- | @GET /v1/health@: always succeeds with 'NoContent'.
healthHandler :: App NoContent
healthHandler = pure NoContent

-- | @GET /v1/layout@: the text that Servant's @layout@ produces for 'Api'.
layoutHandler :: App Text
layoutHandler = pure (Servant.layout (Proxy :: Proxy Api))
-- AppAuthenticated
-- ----------------------------------------------------------------------------
-- | The part of 'Dependencies' needed by the authenticated layer.
data DependenciesAuthenticated = DependenciesAuthenticated
  { authKey :: Text
  , organizationService :: OrganizationService
  }

-- | Environment for authenticated handlers: the parent 'AppEnv' plus the
-- authenticated user and the organization service client.
data AppAuthenticatedEnv = AppAuthenticatedEnv
  { appEnv :: AppEnv
  , authEnv :: AuthEnv
  , appOrganizationService :: OrganizationService
  }

-- | Monad for handlers that require an authenticated user.
newtype AppAuthenticated a = AppAuthenticated
  {unAppAuthenticated :: ReaderT AppAuthenticatedEnv IO a}
  deriving newtype
    (Functor, Applicative, Monad, MonadIO, MonadReader AppAuthenticatedEnv)
-- | Environments that embed an 'AppAuthenticatedEnv'. Such environments
-- must also expose the parent 'AppEnv' (superclass 'HasApp').
class (HasApp env) => HasAppAuthenticated env where
  getAppAuthenticated :: env -> AppAuthenticatedEnv

instance HasApp AppAuthenticatedEnv where
  getApp AppAuthenticatedEnv {appEnv = parentEnv} = parentEnv

instance HasAppAuthenticated AppAuthenticatedEnv where
  getAppAuthenticated = identity

-- | Any authenticated environment can answer 'getUserId'.
instance (HasAppAuthenticated env) => HasAuth env where
  getAuth env = authEnv (getAppAuthenticated env)

-- | Log through the environment's 'LogFunc'.
instance MonadLogger AppAuthenticated where
  monadLoggerLog loc logSource logLevel msg =
    asks getLogFunc >>= \logFunc ->
      liftIO (logFunc loc logSource logLevel (toLogStr msg))
-- | Authenticate the request, then run an 'AppAuthenticated' action inside
-- 'App'.
--
-- The current 'AppEnv' is extended with the resolved user and the
-- organization service client. Authentication failures surface as the
-- 'err401' thrown by 'authenticateUser'.
runAppAuthenticated
  :: DependenciesAuthenticated
  -> Maybe AuthorizationHeader
  -> AppAuthenticated a
  -> App a
runAppAuthenticated
  DependenciesAuthenticated {authKey = key, organizationService = orgClient}
  maybeAuthHeader
  action = do
    authenticatedUser <- authenticateUser key maybeAuthHeader
    let extend parentEnv =
          AppAuthenticatedEnv
            { appEnv = parentEnv
            , authEnv = AuthEnv {userId = authenticatedUser}
            , appOrganizationService = orgClient
            }
    App (withReaderT extend (unAppAuthenticated action))
-- | @GET /v1/organizations@.
--
-- Fetches and logs the user's organizations, then fails with 500 because
-- the response body is not implemented yet.
listOrganizationsHandler :: AppAuthenticated ListOrganizationsResponse
listOrganizationsHandler = traced "list_organizations" $ do
  currentUser <- getUserId
  orgClient <- asks appOrganizationService
  orgs <- liftIO (fetchUserOrganizations orgClient currentUser)
  logInfo $
    "fetched organizations"
      :# [ "user_id" .= currentUser
         , "organizations" .= fmap organizationId orgs
         ]
  liftIO (throwIO (err500 {errBody = "Not implemented"}))
-- | Handlers for 'AuthenticatedApi'.
--
-- Project routes are written in 'AppProject' and hoisted into
-- 'AppAuthenticated' by 'runAppProject', which loads the organization from
-- the captured path segment.
authenticatedServer
  :: Maybe AuthorizationHeader
  -> AuthenticatedApi (AsServerT AppAuthenticated)
authenticatedServer _maybeAuthHeader =
  AuthenticatedApi
    { listOrganizations = listOrganizationsHandler
    , projectApi = \orgId ->
        hoistServer
          (Proxy @(NamedRoutes ProjectApi))
          (runAppProject orgId)
          (projectServer orgId)
    }
-- AppProject
-- ----------------------------------------------------------------------------
-- | A project. In this file only 'findProjectById' constructs one.
data Project = Project
  { projectId :: ProjectId
  , name :: Text
  }

-- | Environment for project handlers: the authenticated environment plus the
-- organization taken from the URL.
data AppProjectEnv = AppProjectEnv
  { appAuthenticatedEnv :: AppAuthenticatedEnv
  , projectOrganization :: Organization
  }

-- | Monad for handlers scoped to one organization's projects.
newtype AppProject a = AppProject
  {unAppProject :: ReaderT AppProjectEnv IO a}
  deriving newtype
    (Functor, Applicative, Monad, MonadIO, MonadReader AppProjectEnv)
-- | Environments that embed an 'AppProjectEnv'.
class (HasAppAuthenticated env) => HasAppProject env where
  getAppProject :: env -> AppProjectEnv

-- Ancestor environments are reached through the embedded
-- 'AppAuthenticatedEnv'.
instance HasApp AppProjectEnv where
  getApp = getApp . appAuthenticatedEnv

instance HasAppAuthenticated AppProjectEnv where
  getAppAuthenticated = appAuthenticatedEnv

instance HasAppProject AppProjectEnv where
  getAppProject = identity

-- | Log through the environment's 'LogFunc'.
instance MonadLogger AppProject where
  monadLoggerLog loc logSource logLevel msg = do
    logFunc <- asks getLogFunc
    liftIO (logFunc loc logSource logLevel (toLogStr msg))
-- | Fetch the organization named in the URL, then run an 'AppProject' action
-- inside 'AppAuthenticated' with that organization in its environment.
runAppProject
  :: OrganizationId
  -> AppProject a
  -> AppAuthenticated a
runAppProject orgId action = do
  orgClient <- asks appOrganizationService
  org <- liftIO (fetchOrganization orgClient orgId)
  let withOrganization parentEnv =
        AppProjectEnv
          { appAuthenticatedEnv = parentEnv
          , projectOrganization = org
          }
  AppAuthenticated (withReaderT withOrganization (unAppProject action))
-- | Handlers for 'ProjectApi'.
--
-- Ticket routes are written in 'AppTicket' and hoisted into 'AppProject' by
-- 'runAppTicket', which loads the project named in the URL.
projectServer :: OrganizationId -> ProjectApi (AsServerT AppProject)
projectServer _organizationId =
  ProjectApi
    { createProject = createProjectHandler
    , getProject = getProjectHandler
    , ticketApi = \pid ->
        hoistServer
          (Proxy @(NamedRoutes TicketApi))
          (runAppTicket pid)
          (ticketServer pid)
    }
-- | @POST .../projects@.
--
-- Runs the (fake) insert with the request body as the project name and
-- logs it, then fails with 500 because the response is not implemented.
createProjectHandler :: CreateProjectRequest -> AppProject CreateProjectResponse
createProjectHandler projectName = traced "create_project" $ do
  currentUser <- getUserId
  orgId <- fmap organizationId getProjectOrganization
  _ <-
    runDatabase $
      query
        "insert into projects (name, organization_id) values (?, ?) returning id"
        (projectName, orgId)
  logInfo $
    "created project"
      :# [ "user_id" .= currentUser
         , "organization_id" .= orgId
         ]
  liftIO (throwIO (err500 {errBody = "Not implemented"}))
-- | @GET .../projects/:projectId@.
--
-- Looks up the project and logs the access, then fails with 500 because the
-- response is not implemented. The lookup result is currently discarded.
getProjectHandler :: ProjectId -> AppProject GetProjectResponse
getProjectHandler pid = traced "get_project" $ do
  currentUser <- getUserId
  orgId <- fmap organizationId getProjectOrganization
  _ <- runDatabase (findProjectById pid)
  logInfo $
    "fetched project"
      :# [ "user_id" .= currentUser
         , "organization_id" .= orgId
         ]
  liftIO (throwIO (err500 {errBody = "Not implemented"}))
-- | The organization that the current project routes are scoped to.
getProjectOrganization
  :: (MonadReader env m, HasAppProject env) => m Organization
getProjectOrganization = projectOrganization . getAppProject <$> ask
-- | Fake lookup: runs the select (whose rows are ignored) and always returns
-- a project named "My project" with the requested id.
findProjectById :: ProjectId -> Database (Maybe Project)
findProjectById pid = do
  _ <- query "select id, name from projects where id = ?" pid
  pure (Just Project {projectId = pid, name = "My project"})
-- AppTicket
-- ----------------------------------------------------------------------------
-- | Environment for ticket handlers: the project environment plus the
-- project loaded from the URL.
data AppTicketEnv = AppTicketEnv
  { appProjectEnv :: AppProjectEnv
  , ticketProject :: Project
  }

-- | Monad for handlers scoped to one project's tickets.
newtype AppTicket a = AppTicket
  {unAppTicket :: ReaderT AppTicketEnv IO a}
  deriving newtype
    (Functor, Applicative, Monad, MonadIO, MonadReader AppTicketEnv)
-- | Environments that embed an 'AppTicketEnv'.
class (HasAppProject env) => HasAppTicket env where
  getAppTicket :: env -> AppTicketEnv

-- Ancestor environments are reached through the embedded 'AppProjectEnv'.
instance HasApp AppTicketEnv where
  getApp = getApp . appProjectEnv

instance HasAppAuthenticated AppTicketEnv where
  getAppAuthenticated = getAppAuthenticated . appProjectEnv

instance HasAppProject AppTicketEnv where
  getAppProject = appProjectEnv

instance HasAppTicket AppTicketEnv where
  getAppTicket = identity

-- | Log through the environment's 'LogFunc'.
instance MonadLogger AppTicket where
  monadLoggerLog loc logSource logLevel msg =
    asks getLogFunc >>= \logFunc ->
      liftIO (logFunc loc logSource logLevel (toLogStr msg))
-- | Load the project named in the URL (responding 404 if it is absent), then
-- run an 'AppTicket' action inside 'AppProject' with that project in scope.
runAppTicket
  :: ProjectId
  -> AppTicket a
  -> AppProject a
runAppTicket pid action = do
  project <- runDatabase (findProjectById pid) >>= maybe projectNotFound pure
  let withProject parentEnv =
        AppTicketEnv {appProjectEnv = parentEnv, ticketProject = project}
  AppProject (withReaderT withProject (unAppTicket action))
  where
    projectNotFound :: AppProject Project
    projectNotFound =
      liftIO (throwIO (err404 {errBody = "Project not found"}))
-- | Handlers for 'TicketApi'. The project id is unused here because
-- 'runAppTicket' has already placed the project in the environment.
ticketServer :: ProjectId -> TicketApi (AsServerT AppTicket)
ticketServer _projectId =
  TicketApi
    { createTicket = createTicketHandler
    , getTicket = getTicketHandler
    }
-- | @POST .../tickets@.
--
-- Runs the (fake) insert with the request body as the ticket name and logs
-- it, then fails with 500 because the response is not implemented.
createTicketHandler :: CreateTicketRequest -> AppTicket CreateTicketResponse
createTicketHandler ticketName = traced "create_ticket" $ do
  currentUser <- getUserId
  orgId <- fmap organizationId getProjectOrganization
  pid <- fmap projectId getTicketProject
  _ <-
    runDatabase $
      query
        "insert into tickets (name, project_id) values (?, ?) returning id"
        (ticketName, pid)
  logInfo $
    "created ticket"
      :# [ "user_id" .= currentUser
         , "organization_id" .= orgId
         , "project_id" .= pid
         ]
  liftIO (throwIO (err500 {errBody = "Not implemented"}))
-- | @GET .../tickets/:ticketId@.
--
-- Runs the (fake) select and logs the access, then fails with 500 because
-- the response is not implemented.
getTicketHandler :: TicketId -> AppTicket GetTicketResponse
getTicketHandler ticketId = traced "get_ticket" $ do
  currentUser <- getUserId
  orgId <- fmap organizationId getProjectOrganization
  pid <- fmap projectId getTicketProject
  _ <-
    runDatabase $
      query "select id, name from tickets where id = ?" ticketId
  logInfo $
    "fetched ticket"
      :# [ "user_id" .= currentUser
         , "organization_id" .= orgId
         , "project_id" .= pid
         ]
  liftIO (throwIO (err500 {errBody = "Not implemented"}))
-- | The project that the current ticket routes are scoped to.
getTicketProject
  :: (MonadReader env m, HasAppTicket env) => m Project
getTicketProject = ticketProject . getAppTicket <$> ask
-- Main
-- ----------------------------------------------------------------------------
-- | Build the shared 'Dependencies' and serve 'Api' with Warp on port 3000.
--
-- Environment variables:
--
-- * @AUTH_KEY@: key handed to the authentication layer (default @abc123@).
-- * @ORGANIZATION_SERVICE_URL@: base URL of the organization service client
--   (default @http://localhost:3001@). The older @PROJECT_SERVICE_URL@ name,
--   which configured this same client, is still honoured as a fallback.
main :: IO ()
main = do
  authKey <- toText . fromMaybe "abc123" <$> lookupEnv "AUTH_KEY"
  -- The URL configures the *organization* service client. The legacy
  -- PROJECT_SERVICE_URL name is kept so existing deployments keep working.
  organizationServiceUrlVar <- lookupEnv "ORGANIZATION_SERVICE_URL"
  legacyServiceUrlVar <- lookupEnv "PROJECT_SERVICE_URL"
  let organizationServiceUrl =
        toText . fromMaybe "http://localhost:3001" $
          organizationServiceUrlVar <|> legacyServiceUrlVar
  dbPool <- createDbPool "app:app@localhost:5432/app" 10
  tracer <- createTracer "app"
  httpManager <- newManager defaultManagerSettings {managerConnCount = 20}
  let port = 3000
      dependencies =
        Dependencies
          { dbPool = dbPool
          , depsLogger = Logger.defaultOutput stdout
          , tracer = tracer
          , authKey = authKey
          , organizationService =
              createOrganizationServiceClient httpManager organizationServiceUrl
          }
      waiApp = serve (Proxy @Api) (server dependencies)
  Warp.run port waiApp
cabal-version: 3.0
name: servant-nested-apis
version: 1.0.0

-- Settings shared by every component through `import: options`.
common options
  -- Each package is listed once.
  build-depends:
    , base
    , http-client
    , monad-logger
    , monad-logger-aeson
    , relude
    , resource-pool
    , servant
    , servant-server
    , warp
  ghc-options:
    -Wall
    -Wcompat
    -Widentities
    -Wincomplete-uni-patterns
    -Wincomplete-record-updates
    -Wredundant-constraints
    -Wmissing-export-lists
    -Wpartial-fields
    -Wunused-packages
  default-language: GHC2021
  -- NoImplicitPrelude: the module imports Relude as its prelude instead.
  default-extensions:
    DeriveAnyClass
    DerivingStrategies
    DerivingVia
    DuplicateRecordFields
    NoImplicitPrelude
    OverloadedRecordDot
    OverloadedStrings
    StrictData

-- The single-module server executable.
executable servant-nested-apis
  import: options
  main-is: Main.hs
  hs-source-dirs: .
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment