use std::{ops::Deref, path::Path};
use rusqlite::Connection;
use sql_support::{open_database, ConnExt};
use crate::error::{PushError, Result};
use super::{record::PushRecord, schema};
/// Persistence operations for push subscriptions and push-related metadata.
///
/// [`PushDb`] (SQLite-backed) is the implementation provided in this module.
pub trait Storage: Sized {
    /// Opens the store located at `path`.
    ///
    /// The `PushDb` implementation ignores `path` in test builds and opens an in-memory
    /// database instead.
    fn open<P: AsRef<Path>>(path: P) -> Result<Self>;

    /// Returns the record for channel id `chid`, or `None` if there is none.
    fn get_record(&self, chid: &str) -> Result<Option<PushRecord>>;

    /// Returns the record registered for `scope`, or `None` if there is none.
    fn get_record_by_scope(&self, scope: &str) -> Result<Option<PushRecord>>;

    /// Inserts or replaces `record`; `PushDb` returns `true` when exactly one row was written.
    fn put_record(&self, record: &PushRecord) -> Result<bool>;

    /// Deletes the record for `chid`; `PushDb` returns `true` when exactly one row was removed.
    fn delete_record(&self, chid: &str) -> Result<bool>;

    /// Deletes every record (the `PushDb` implementation also clears `uaid` and `auth`).
    fn delete_all_records(&self) -> Result<()>;

    /// Returns the channel ids of all stored records.
    fn get_channel_list(&self) -> Result<Vec<String>>;

    /// Replaces the endpoint stored for `channel_id`; `PushDb` returns `true` when exactly one
    /// row was updated.
    // The allow suppresses rustc's unused-method lint; presumably nothing outside tests calls
    // this yet. NOTE(review): drop the allow once a production caller exists.
    #[allow(dead_code)]
    fn update_endpoint(&self, channel_id: &str, endpoint: &str) -> Result<bool>;

    /// Returns the stored `uaid` metadata value, if any.
    fn get_uaid(&self) -> Result<Option<String>>;

    /// Stores `uaid` as metadata.
    fn set_uaid(&self, uaid: &str) -> Result<()>;

    /// Returns the stored `auth` metadata value, if any.
    fn get_auth(&self) -> Result<Option<String>>;

    /// Stores `auth` as metadata.
    fn set_auth(&self, auth: &str) -> Result<()>;

    /// Returns the stored `registration_id` metadata value, if any.
    fn get_registration_id(&self) -> Result<Option<String>>;

    /// Stores `registration_id` as metadata.
    fn set_registration_id(&self, registration_id: &str) -> Result<()>;

    /// Returns the metadata value stored under `key`, if any.
    fn get_meta(&self, key: &str) -> Result<Option<String>>;

    /// Stores `value` under metadata `key`.
    fn set_meta(&self, key: &str, value: &str) -> Result<()>;
}
/// SQLite-backed [`Storage`] implementation for push subscriptions and metadata.
///
/// Constructed through [`PushDb::open`], which runs the push schema initializer.
pub struct PushDb {
    /// The underlying database connection; the `Deref` and `ConnExt` impls on this type
    /// expose it as well.
    pub db: Connection,
}
impl PushDb {
    /// Opens (creating if necessary) the push database at `path`, applying the push schema.
    ///
    /// # Errors
    /// Returns [`PushError::StorageError`] naming the path and the underlying failure when
    /// the database cannot be opened or initialized.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let db_path = path.as_ref();
        let init = schema::PushConnectionInitializer {};
        match open_database::open_database(db_path, &init) {
            Ok(db) => Ok(Self { db }),
            Err(cause) => Err(PushError::StorageError(format!(
                "Could not open database file {:?} - {}",
                &db_path.as_os_str(),
                cause,
            ))),
        }
    }

    /// Opens a fresh in-memory database with the push schema applied (test builds only).
    #[cfg(test)]
    pub fn open_in_memory() -> Result<Self> {
        // Another test may already have installed the logger; that error is irrelevant.
        env_logger::try_init().ok();
        let init = schema::PushConnectionInitializer {};
        Ok(Self {
            db: open_database::open_memory_database(&init)?,
        })
    }

    /// Canonicalizes a channel id: every `-` is removed, then the result is lowercased.
    ///
    /// Dashes are stripped *before* lowercasing on purpose; the order is observable because
    /// `str::to_lowercase` is context-sensitive for some characters.
    pub fn normalize_uuid(uuid: &str) -> String {
        let undashed: String = uuid.split('-').collect();
        undashed.to_lowercase()
    }
}
impl Deref for PushDb {
type Target = Connection;
fn deref(&self) -> &Connection {
&self.db
}
}
impl ConnExt for PushDb {
    /// Hands the wrapped connection to `sql_support`'s extension helpers.
    fn conn(&self) -> &Connection {
        let Self { db } = self;
        db
    }
}
impl Storage for PushDb {
    /// Fetches the record for channel `chid`.
    ///
    /// `chid` is passed through [`PushDb::normalize_uuid`] (dashes stripped, lowercased), so
    /// dashed and undashed spellings find the same row. Returns `Ok(None)` when nothing matches.
    fn get_record(&self, chid: &str) -> Result<Option<PushRecord>> {
        let query = format!(
            "SELECT {common_cols}\n\
             FROM push_record WHERE channel_id = :chid",
            common_cols = schema::COMMON_COLS,
        );
        self.try_query_row(
            &query,
            &[(":chid", &Self::normalize_uuid(chid))],
            PushRecord::from_row,
            false,
        )
    }

    /// Fetches the record whose `scope` equals `scope` exactly (no normalization is applied).
    /// Returns `Ok(None)` when nothing matches.
    fn get_record_by_scope(&self, scope: &str) -> Result<Option<PushRecord>> {
        let query = format!(
            "SELECT {common_cols}\n\
             FROM push_record WHERE scope = :scope",
            common_cols = schema::COMMON_COLS,
        );
        self.try_query_row(&query, &[(":scope", scope)], PushRecord::from_row, false)
    }

    /// Writes `record` with `INSERT OR REPLACE`, so a conflicting existing row is replaced.
    ///
    /// The channel id is normalized before storage. Returns `true` when SQLite reports exactly
    /// one affected row.
    fn put_record(&self, record: &PushRecord) -> Result<bool> {
        log::debug!(
            "adding push subscription for scope '{}', channel '{}', endpoint '{}'",
            record.scope,
            record.channel_id,
            record.endpoint
        );
        let query = format!(
            "INSERT OR REPLACE INTO push_record\n\
             ({common_cols})\n\
             VALUES\n\
             (:channel_id, :endpoint, :scope, :key, :ctime, :app_server_key)",
            common_cols = schema::COMMON_COLS,
        );
        let affected_rows = self.execute(
            &query,
            &[
                (
                    ":channel_id",
                    &Self::normalize_uuid(&record.channel_id) as &dyn rusqlite::ToSql,
                ),
                (":endpoint", &record.endpoint),
                (":scope", &record.scope),
                (":key", &record.key),
                (":ctime", &record.ctime),
                (":app_server_key", &record.app_server_key),
            ],
        )?;
        Ok(affected_rows == 1)
    }

    /// Deletes the record for channel `chid` (normalized first).
    ///
    /// Returns `true` when exactly one row was removed and `false` when nothing matched.
    fn delete_record(&self, chid: &str) -> Result<bool> {
        log::debug!("deleting push subscription: {}", chid);
        let affected_rows = self.execute(
            "DELETE FROM push_record\n\
             WHERE channel_id = :chid",
            &[(":chid", &Self::normalize_uuid(chid))],
        )?;
        Ok(affected_rows == 1)
    }

    /// Deletes every subscription plus the `uaid` and `auth` metadata entries.
    ///
    /// All deletions run inside a single transaction. If any statement fails, everything is
    /// rolled back when `tx` is dropped uncommitted. This avoids the state where subscriptions
    /// are gone but the old `uaid`/`auth` are still stored.
    fn delete_all_records(&self) -> Result<()> {
        log::debug!("deleting all push subscriptions and some metadata");
        // NOTE(review): `unchecked_transaction` issues BEGIN, which fails if the caller
        // already has a transaction open on this connection; confirm no caller does.
        let tx = self.db.unchecked_transaction()?;
        tx.execute("DELETE FROM push_record", [])?;
        tx.execute_batch(
            "DELETE FROM meta_data WHERE key='uaid';\n\
             DELETE FROM meta_data WHERE key='auth';\n",
        )?;
        tx.commit()?;
        Ok(())
    }

    /// Returns the `channel_id` column of every stored record (as stored, i.e. normalized
    /// when written through `put_record`).
    fn get_channel_list(&self) -> Result<Vec<String>> {
        self.query_rows_and_then(
            "SELECT channel_id FROM push_record",
            [],
            |row| -> Result<String> { Ok(row.get(0)?) },
        )
    }

    /// Sets the endpoint of the record for `channel_id` (normalized first).
    ///
    /// Returns `true` when exactly one row was updated.
    fn update_endpoint(&self, channel_id: &str, endpoint: &str) -> Result<bool> {
        log::debug!("updating endpoint for '{}' to '{}'", channel_id, endpoint);
        let affected_rows = self.execute(
            "UPDATE push_record set endpoint = :endpoint\n\
             WHERE channel_id = :channel_id",
            &[
                (":endpoint", &endpoint as &dyn rusqlite::ToSql),
                (":channel_id", &Self::normalize_uuid(channel_id)),
            ],
        )?;
        Ok(affected_rows == 1)
    }

    /// Reads the `uaid` metadata entry.
    fn get_uaid(&self) -> Result<Option<String>> {
        self.get_meta("uaid")
    }

    /// Writes the `uaid` metadata entry.
    fn set_uaid(&self, uaid: &str) -> Result<()> {
        self.set_meta("uaid", uaid)
    }

    /// Reads the `auth` metadata entry.
    fn get_auth(&self) -> Result<Option<String>> {
        self.get_meta("auth")
    }

    /// Writes the `auth` metadata entry.
    fn set_auth(&self, auth: &str) -> Result<()> {
        self.set_meta("auth", auth)
    }

    /// Reads the `registration_id` metadata entry.
    fn get_registration_id(&self) -> Result<Option<String>> {
        self.get_meta("registration_id")
    }

    /// Writes the `registration_id` metadata entry.
    fn set_registration_id(&self, registration_id: &str) -> Result<()> {
        self.set_meta("registration_id", registration_id)
    }

    /// Reads the `meta_data` value stored under `key`, or `None` when the key is absent.
    /// SQL errors are reported as [`PushError::StorageSqlError`].
    fn get_meta(&self, key: &str) -> Result<Option<String>> {
        self.try_query_one(
            "SELECT value FROM meta_data where key = :key limit 1",
            &[(":key", &key)],
            // Presumably the sql_support "cache the prepared statement" flag.
            true,
        )
        .map_err(PushError::StorageSqlError)
    }

    /// Writes `value` under `key` with `INSERT or REPLACE`, using a cached statement.
    fn set_meta(&self, key: &str, value: &str) -> Result<()> {
        let query = "INSERT or REPLACE into meta_data (key, value) values (:k, :v)";
        self.execute_cached(query, &[(":k", &key), (":v", &value)])?;
        Ok(())
    }

    /// Opens the on-disk database at `path` via [`PushDb::open`].
    #[cfg(not(test))]
    fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        PushDb::open(path)
    }

    /// Test builds ignore `path` and open a fresh in-memory database instead.
    #[cfg(test)]
    fn open<P: AsRef<Path>>(_path: P) -> Result<Self> {
        PushDb::open_in_memory()
    }
}
#[cfg(test)]
mod test {
    use crate::error::Result;
    use crate::internal::crypto::{Crypto, Cryptography};

    use super::PushDb;

    use crate::internal::crypto::get_random_bytes;
    use crate::internal::storage::{db::Storage, record::PushRecord};

    const DUMMY_UAID: &str = "abad1dea00000000aabbccdd00000000";

    /// Opens a fresh in-memory database, initializing logging if not already done.
    fn get_db() -> Result<PushDb> {
        env_logger::try_init().ok();
        PushDb::open_in_memory()
    }

    /// Returns 16 random bytes rendered as 32 lowercase hex characters.
    fn get_uuid() -> Result<String> {
        Ok(get_random_bytes(16)?
            .iter()
            .map(|b| format!("{:02x}", b))
            .collect::<Vec<String>>()
            .join(""))
    }

    /// Builds a record for `chid` with scope `https://example.com/` and a freshly generated key.
    fn prec(chid: &str) -> PushRecord {
        PushRecord::new(
            chid,
            &format!("https://example.com/update/{}", chid),
            "https://example.com/",
            Crypto::generate_key().expect("Couldn't generate_key"),
        )
        .unwrap()
    }

    #[test]
    fn basic() -> Result<()> {
        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);

        assert!(db.get_record(chid)?.is_none());
        db.put_record(&rec)?;
        assert!(db.get_record(chid)?.is_some());
        // don't fail if you've already added this record.
        db.put_record(&rec)?;
        // make sure that fetching the same uaid & chid returns the same record.
        assert_eq!(db.get_record(chid)?, Some(rec.clone()));

        let mut rec2 = rec.clone();
        rec2.endpoint = format!("https://example.com/update2/{}", chid);
        db.put_record(&rec2)?;
        let result = db.get_record(chid)?.unwrap();
        assert_ne!(result, rec);
        assert_eq!(result, rec2);

        let result = db.get_record_by_scope("https://example.com/")?.unwrap();
        assert_eq!(result, rec2);
        Ok(())
    }

    #[test]
    fn delete() -> Result<()> {
        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);

        assert!(db.put_record(&rec)?);
        assert!(db.get_record(chid)?.is_some());
        assert!(db.delete_record(chid)?);
        assert!(db.get_record(chid)?.is_none());
        Ok(())
    }

    #[test]
    fn delete_all_records() -> Result<()> {
        let db = get_db()?;
        let chid = &get_uuid()?;
        let rec = prec(chid);
        let mut rec2 = rec.clone();
        rec2.channel_id = get_uuid()?;
        rec2.endpoint = format!("https://example.com/update/{}", &rec2.channel_id);

        assert!(db.put_record(&rec)?);
        assert!(db.put_record(&rec2)?);
        // rec2 shares rec's scope and displaces it; presumably the schema allows only one
        // record per scope. NOTE(review): confirm against `schema`.
        assert!(db.get_record(&rec.channel_id)?.is_none());
        assert!(db.get_record(&rec2.channel_id)?.is_some());

        // Populate the metadata delete_all_records is supposed to clear, so the assertions
        // below actually exercise that path instead of passing vacuously.
        db.set_uaid(DUMMY_UAID)?;
        db.set_auth("dummy-auth")?;
        assert_eq!(db.get_uaid()?, Some(DUMMY_UAID.to_owned()));
        assert_eq!(db.get_auth()?, Some("dummy-auth".to_owned()));

        db.delete_all_records()?;
        assert!(db.get_record(&rec.channel_id)?.is_none());
        assert!(db.get_record(&rec2.channel_id)?.is_none());
        assert!(db.get_uaid()?.is_none());
        assert!(db.get_auth()?.is_none());
        Ok(())
    }

    #[test]
    fn meta() -> Result<()> {
        let db = get_db()?;
        let no_rec = db.get_uaid()?;
        assert_eq!(no_rec, None);
        db.set_uaid(DUMMY_UAID)?;
        db.set_meta("fruit", "apple")?;
        // A second write for the same key must win.
        db.set_meta("fruit", "banana")?;
        assert_eq!(db.get_uaid()?, Some(DUMMY_UAID.to_owned()));
        assert_eq!(db.get_meta("fruit")?, Some("banana".to_owned()));
        Ok(())
    }

    #[test]
    fn dash() -> Result<()> {
        let db = get_db()?;
        // Channel ids are normalized (dashes stripped), so the dashed form must still
        // find and delete the stored row.
        let chid = "deadbeef-0000-0000-0000-decafbad12345678";

        let rec = prec(chid);

        assert!(db.put_record(&rec)?);
        assert!(db.get_record(chid)?.is_some());
        assert!(db.delete_record(chid)?);
        Ok(())
    }
}