Skip to content

Instantly share code, notes, and snippets.

@sagoez
Last active May 5, 2024 19:52
Show Gist options
  • Select an option

  • Save sagoez/5b5a805050b694cc4092cd067c9c5048 to your computer and use it in GitHub Desktop.

Select an option

Save sagoez/5b5a805050b694cc4092cd067c9c5048 to your computer and use it in GitHub Desktop.

Revisions

  1. sagoez renamed this gist May 5, 2024. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  2. sagoez created this gist May 5, 2024.
    284 changes: 284 additions & 0 deletions cache.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,284 @@
    use crate::{
    domain::{Error, Unit},
    service::Cache as CacheConfig,
    };
    use async_trait::async_trait;
    use deadpool_redis::{
    redis::{
    cmd, from_redis_value, AsyncCommands, FromRedisValue, JsonAsyncCommands, RedisError,
    RedisResult, RedisWrite, ToRedisArgs,
    },
    Connection, Pool, PoolConfig, Runtime,
    };
    use futures::Future;
    use secrecy::ExposeSecret;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;

    #[async_trait]
    pub trait CacheExt {
    async fn get_or_insert_with<F, Fut, T>(
    &self,
    key: &str,
    f: F,
    expire: Option<u64>,
    ) -> Result<T, Error>
    where
    T: for<'de> Deserialize<'de> + Serialize + Send + Clone,
    F: FnOnce() -> Fut + Send,
    Fut: Future<Output = Result<T, Error>> + Send;

    async fn get<T>(&self, key: &str) -> Result<Option<T>, Error>
    where
    T: for<'de> Deserialize<'de>;
    async fn set<T>(&self, key: &str, value: T, expire: Option<u64>) -> Result<Unit, Error>
    where
    T: Serialize + Send;
    async fn remove(&self, key: &str) -> Result<Unit, Error>;
    async fn clear(&self) -> Result<Unit, Error>;
    }

    #[derive(Debug, Clone)]
    struct CacheEntry {
    key: String,
    value: Option<ValueWrapper>,
    }

    impl CacheEntry {
    pub fn new(key: String, value: Value) -> Self {
    Self {
    key,
    value: Some(ValueWrapper(value)),
    }
    }

    fn get_as<T>(&self) -> Option<T>
    where
    T: for<'de> Deserialize<'de>,
    {
    match &self.value {
    Some(value) => serde_json::from_value(value.0.clone()).ok(),
    None => None,
    }
    }
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    #[serde(transparent)]
    pub struct ValueWrapper(Value);

    impl ToRedisArgs for ValueWrapper {
    fn write_redis_args<W>(&self, out: &mut W)
    where
    W: ?Sized + RedisWrite,
    {
    let json_str = serde_json::to_string(&self).unwrap();

    out.write_arg(json_str.as_bytes());
    }
    }

    impl FromRedisValue for ValueWrapper {
    fn from_redis_value(v: &deadpool_redis::redis::Value) -> RedisResult<Self> {
    let json_str: Option<String> = from_redis_value(v)?;

    match json_str {
    Some(json_str) => {
    let json: Value = serde_json::from_str(json_str.as_str()).map_err(|error| {
    RedisError::from(std::io::Error::new(
    std::io::ErrorKind::InvalidData,
    error.to_string(),
    ))
    })?;

    Ok(Self(json))
    }
    None => Ok(Self(Value::Null)),
    }
    }
    }

    impl ToRedisArgs for CacheEntry {
    fn write_redis_args<W>(&self, out: &mut W)
    where
    W: ?Sized + RedisWrite,
    {
    let json = serde_json::json!(
    {
    "key": self.key,
    "value": self.value
    }
    );

    let json_str = json.to_string();

    out.write_arg(json_str.as_bytes());
    }
    }

    impl FromRedisValue for CacheEntry {
    fn from_redis_value(v: &deadpool_redis::redis::Value) -> RedisResult<Self> {
    let json_str: Option<String> = from_redis_value(v)?;

    match json_str {
    Some(json_str) => {
    let json: Value = serde_json::from_str(json_str.as_str()).map_err(|error| {
    RedisError::from(std::io::Error::new(
    std::io::ErrorKind::InvalidData,
    error.to_string(),
    ))
    })?;

    let key = json.get("key").ok_or(RedisError::from(std::io::Error::new(
    std::io::ErrorKind::InvalidData,
    "key not found",
    )))?;
    let key = key.as_str().ok_or(RedisError::from(std::io::Error::new(
    std::io::ErrorKind::InvalidData,
    "key not found",
    )))?;
    let key = key.to_string();

    let value = json
    .get("value")
    .ok_or(RedisError::from(std::io::Error::new(
    std::io::ErrorKind::InvalidData,
    "value not found",
    )))?;
    let value = value.clone();

    Ok(Self {
    key,
    value: Some(ValueWrapper(value)),
    })
    }
    None => Ok(Self {
    key: Default::default(),
    value: None,
    }),
    }
    }
    }

    pub struct Cache {
    pool: Pool,
    }

    impl Cache {
    pub fn new(configuration: &CacheConfig) -> Result<Self, Error> {
    let redis = deadpool_redis::Config {
    url: Some(configuration.url().expose_secret().into()),
    connection: None,
    pool: Some(PoolConfig {
    max_size: configuration.pool_size(),
    timeouts: configuration.timeouts(),
    ..Default::default()
    }),
    };
    let pool = redis.create_pool(Some(Runtime::Tokio1)).map_err(|error| {
    Error::cache_generic(
    "There was an error with the configuration",
    None,
    None,
    error.into(),
    )
    })?;

    Ok(Self { pool })
    }

    pub async fn connection(&self) -> Result<Connection, Error> {
    let connection = self.pool.get().await.map_err(|error| {
    Error::cache_generic(
    "There was an error with the connection",
    None,
    None,
    error.into(),
    )
    })?;

    Ok(connection)
    }
    }

    #[async_trait]
    impl CacheExt for Cache {
    #[tracing::instrument(name = "cache::get_or_insert_with", skip(self, key, f))]
    async fn get_or_insert_with<F, Fut, T>(
    &self,
    key: &str,
    f: F,
    expire: Option<u64>,
    ) -> Result<T, Error>
    where
    T: for<'de> Deserialize<'de> + Serialize + Send + Clone,
    F: FnOnce() -> Fut + Send,
    Fut: Future<Output = Result<T, Error>> + Send,
    {
    match self.get(key).await? {
    Some(entry) => {
    tracing::debug!("Cache hit for key: {}", key);
    Ok(entry)
    }
    None => {
    let value = f().await?;
    self.set(key, value.clone(), expire).await?;
    Ok(value)
    }
    }
    }

    #[tracing::instrument(name = "cache::get", skip(self, key))]
    async fn get<T>(&self, key: &str) -> Result<Option<T>, Error>
    where
    T: for<'de> Deserialize<'de>,
    {
    let entry: CacheEntry = self.connection().await?.get(key).await?;

    match entry.value {
    Some(_) => Ok(entry.get_as()),
    None => Ok(None),
    }
    }

    #[tracing::instrument(name = "cache::insert", skip(self, key, value, expire))]
    async fn set<T>(&self, key: &str, value: T, expire: Option<u64>) -> Result<Unit, Error>
    where
    T: Serialize + Send,
    {
    let entry = CacheEntry::new(key.to_string(), serde_json::to_value(value)?);

    self.connection()
    .await?
    .set_ex::<_, CacheEntry, Option<String>>(
    entry.key.clone(),
    entry.clone(),
    expire.unwrap_or(86400),
    )
    .await
    .map(|_| ())?;

    Ok(())
    }

    #[tracing::instrument(name = "cache::remove", skip(self, key))]
    async fn remove(&self, key: &str) -> Result<Unit, Error> {
    self.connection()
    .await?
    .del::<_, i64>(key)
    .await
    .map(|_| ())?;

    Ok(())
    }

    #[tracing::instrument(name = "cache::clear", skip(self))]
    async fn clear(&self) -> Result<Unit, Error> {
    cmd("FLUSHALL")
    .query_async(&mut *self.connection().await?)
    .await
    .map(|_: ()| ())?;

    Ok(())
    }
    }