diff --git a/src/invidious.cr b/src/invidious.cr index 94620a26..91f19d69 100644 --- a/src/invidious.cr +++ b/src/invidious.cr @@ -248,7 +248,7 @@ before_all do |env| # Invidious users only have SID if !env.request.cookies.has_key? "SSID" if email = Invidious::Database::SessionIDs.select_email(sid) - user = PG_DB.query_one("SELECT * FROM users WHERE email = $1", email, as: User) + user = Invidious::Database::Users.select!(email: email) csrf_token = generate_response(sid, { ":authorize_token", ":playlist_ajax", @@ -458,10 +458,10 @@ post "/watch_ajax" do |env| case action when "action_mark_watched" if !user.watched.includes? id - PG_DB.exec("UPDATE users SET watched = array_append(watched, $1) WHERE email = $2", id, user.email) + Invidious::Database::Users.mark_watched(user, id) end when "action_mark_unwatched" - PG_DB.exec("UPDATE users SET watched = array_remove(watched, $1) WHERE email = $2", id, user.email) + Invidious::Database::Users.mark_unwatched(user, id) else next error_json(400, "Unsupported action #{action}") end @@ -599,16 +599,15 @@ post "/subscription_ajax" do |env| # Sync subscriptions with YouTube subscribe_ajax(channel_id, action, env.request.headers) end - email = user.email case action when "action_create_subscription_to_channel" if !user.subscriptions.includes? channel_id get_channel(channel_id, PG_DB, false, false) - PG_DB.exec("UPDATE users SET feed_needs_update = true, subscriptions = array_append(subscriptions, $1) WHERE email = $2", channel_id, email) + Invidious::Database::Users.subscribe_channel(user, channel_id) end when "action_remove_subscriptions" - PG_DB.exec("UPDATE users SET feed_needs_update = true, subscriptions = array_remove(subscriptions, $1) WHERE email = $2", channel_id, email) + Invidious::Database::Users.unsubscribe_channel(user, channel_id) else next error_json(400, "Unsupported action #{action}") end @@ -1008,7 +1007,7 @@ post "/delete_account" do |env| end view_name = "subscriptions_#{sha256(user.email)}" - PG_DB.exec("DELETE FROM users * WHERE email = $1", user.email) + Invidious::Database::Users.delete(user) Invidious::Database::SessionIDs.delete(email: user.email) PG_DB.exec("DROP MATERIALIZED VIEW #{view_name}") @@ -1059,7 +1058,7 @@ post "/clear_watch_history" do |env| next error_template(400, ex) end - PG_DB.exec("UPDATE users SET watched = '{}' WHERE email = $1", user.email) + Invidious::Database::Users.clear_watch_history(user) env.redirect referer end diff --git a/src/invidious/database/users.cr b/src/invidious/database/users.cr new file mode 100644 index 00000000..aa3b9f85 --- /dev/null +++ b/src/invidious/database/users.cr @@ -0,0 +1,129 @@ +require "./base.cr" + +module Invidious::Database::Users + extend self + + # ------------------- + # Insert / delete + # ------------------- + + def insert(user : User, update_on_conflict : Bool = false) + user_array = user.to_a + user_array[4] = user_array[4].to_json # User preferences + + request = <<-SQL + INSERT INTO users + VALUES (#{arg_array(user_array)}) + SQL + + if update_on_conflict + request += <<-SQL + ON CONFLICT (email) DO UPDATE + SET updated = $1, subscriptions = $3 + SQL + end + + PG_DB.exec(request, args: user_array) + end + + def delete(user : User) + request = <<-SQL + DELETE FROM users * + WHERE email = $1 + SQL + + PG_DB.exec(request, user.email) + end + + # ------------------- + # Update (history) + # ------------------- + + def mark_watched(user : User, vid : String) + request = <<-SQL + UPDATE users + SET watched = array_append(watched, $1) + WHERE email = $2 + SQL + + PG_DB.exec(request, vid, user.email) + end + + def mark_unwatched(user : User, vid : String) + request = <<-SQL + UPDATE users + SET watched = array_remove(watched, $1) + WHERE email = $2 + SQL + + PG_DB.exec(request, vid, user.email) + end + + def clear_watch_history(user : User) + request = <<-SQL + UPDATE users + SET watched = '{}' + WHERE email = $1 + SQL + + PG_DB.exec(request, user.email) + end + + # ------------------- + # Update (channels) + # ------------------- + + def subscribe_channel(user : User, ucid : String) + request = <<-SQL + UPDATE users + SET feed_needs_update = true, + subscriptions = array_append(subscriptions,$1) + WHERE email = $2 + SQL + + PG_DB.exec(request, ucid, user.email) + end + + def unsubscribe_channel(user : User, ucid : String) + request = <<-SQL + UPDATE users + SET feed_needs_update = true, + subscriptions = array_remove(subscriptions, $1) + WHERE email = $2 + SQL + + PG_DB.exec(request, ucid, user.email) + end + + # ------------------- + # Select + # ------------------- + + def select(*, email : String) : User? + request = <<-SQL + SELECT * FROM users + WHERE email = $1 + SQL + + return PG_DB.query_one?(request, email, as: User) + end + + # Same as select, but can raise an exception + def select!(*, email : String) : User + request = <<-SQL + SELECT * FROM users + WHERE email = $1 + SQL + + return PG_DB.query_one(request, email, as: User) + end + + def select(*, token : String) : User? + request = <<-SQL + SELECT * FROM users + WHERE token = $1 + SQL + + return PG_DB.query_one?(request, token, as: User) + end +end diff --git a/src/invidious/helpers/handlers.cr b/src/invidious/helpers/handlers.cr index 0aa86e64..d52035c7 100644 --- a/src/invidious/helpers/handlers.cr +++ b/src/invidious/helpers/handlers.cr @@ -100,7 +100,7 @@ class AuthHandler < Kemal::Handler scopes, expire, signature = validate_request(token, session, env.request, HMAC_KEY, PG_DB, nil) if email = Invidious::Database::SessionIDs.select_email(session) - user = PG_DB.query_one("SELECT * FROM users WHERE email = $1", email, as: User) + user = Invidious::Database::Users.select!(email: email) end elsif sid = env.request.cookies["SID"]?.try &.value if sid.starts_with? "v1:" @@ -108,7 +108,7 @@ class AuthHandler < Kemal::Handler end if email = Invidious::Database::SessionIDs.select_email(sid) - user = PG_DB.query_one("SELECT * FROM users WHERE email = $1", email, as: User) + user = Invidious::Database::Users.select!(email: email) end scopes = [":*"] diff --git a/src/invidious/routes/api/v1/authenticated.cr b/src/invidious/routes/api/v1/authenticated.cr index c95007c2..d9b58ebf 100644 --- a/src/invidious/routes/api/v1/authenticated.cr +++ b/src/invidious/routes/api/v1/authenticated.cr @@ -94,7 +94,7 @@ module Invidious::Routes::API::V1::Authenticated if !user.subscriptions.includes? ucid get_channel(ucid, PG_DB, false, false) - PG_DB.exec("UPDATE users SET feed_needs_update = true, subscriptions = array_append(subscriptions,$1) WHERE email = $2", ucid, user.email) + Invidious::Database::Users.subscribe_channel(user, ucid) end # For Google accounts, access tokens don't have enough information to @@ -110,7 +110,7 @@ module Invidious::Routes::API::V1::Authenticated ucid = env.params.url["ucid"] - PG_DB.exec("UPDATE users SET feed_needs_update = true, subscriptions = array_remove(subscriptions, $1) WHERE email = $2", ucid, user.email) + Invidious::Database::Users.unsubscribe_channel(user, ucid) env.response.status_code = 204 end diff --git a/src/invidious/routes/feeds.cr b/src/invidious/routes/feeds.cr index 78e6bd40..4e7ec9ad 100644 --- a/src/invidious/routes/feeds.cr +++ b/src/invidious/routes/feeds.cr @@ -220,7 +220,7 @@ module Invidious::Routes::Feeds haltf env, status_code: 403 end - user = PG_DB.query_one?("SELECT * FROM users WHERE token = $1", token.strip, as: User) + user = Invidious::Database::Users.select(token: token.strip) if !user haltf env, status_code: 403 end diff --git a/src/invidious/routes/login.cr b/src/invidious/routes/login.cr index e70206cc..8f703464 100644 --- a/src/invidious/routes/login.cr +++ b/src/invidious/routes/login.cr @@ -327,7 +327,7 @@ module Invidious::Routes::Login return error_template(401, "Password is a required field") end - user = PG_DB.query_one?("SELECT * FROM users WHERE email = $1", email, as: User) + user = Invidious::Database::Users.select(email: email) if user if !user.password @@ -449,12 +449,7 @@ module Invidious::Routes::Login end end - user_array = user.to_a - user_array[4] = user_array[4].to_json # User preferences - - args = arg_array(user_array) - - PG_DB.exec("INSERT INTO users VALUES (#{args})", args: user_array) + Invidious::Database::Users.insert(user) Invidious::Database::SessionIDs.insert(sid, email) view_name = "subscriptions_#{sha256(user.email)}" diff --git a/src/invidious/routes/watch.cr b/src/invidious/routes/watch.cr index b24222ff..c1ec0bc6 100644 --- a/src/invidious/routes/watch.cr +++ b/src/invidious/routes/watch.cr @@ -76,7 +76,7 @@ module Invidious::Routes::Watch env.params.query.delete_all("iv_load_policy") if watched && !watched.includes? id - PG_DB.exec("UPDATE users SET watched = array_append(watched, $1) WHERE email = $2", id, user.as(User).email) + Invidious::Database::Users.mark_watched(user.as(User), id) end if notifications && notifications.includes? id diff --git a/src/invidious/users.cr b/src/invidious/users.cr index 3e9a9e68..933c451d 100644 --- a/src/invidious/users.cr +++ b/src/invidious/users.cr @@ -31,17 +31,12 @@ end def get_user(sid, headers, db, refresh = true) if email = Invidious::Database::SessionIDs.select_email(sid) - user = db.query_one("SELECT * FROM users WHERE email = $1", email, as: User) + user = Invidious::Database::Users.select!(email: email) if refresh && Time.utc - user.updated > 1.minute user, sid = fetch_user(sid, headers, db) - user_array = user.to_a - user_array[4] = user_array[4].to_json # User preferences - args = arg_array(user_array) - - db.exec("INSERT INTO users VALUES (#{args}) \ - ON CONFLICT (email) DO UPDATE SET updated = $1, subscriptions = $3", args: user_array) + Invidious::Database::Users.insert(user, update_on_conflict: true) Invidious::Database::SessionIDs.insert(sid, user.email, handle_conflicts: true) begin @@ -52,13 +47,8 @@ def get_user(sid, headers, db, refresh = true) end else user, sid = fetch_user(sid, headers, db) - user_array = user.to_a - user_array[4] = user_array[4].to_json # User preferences - args = arg_array(user.to_a) - - db.exec("INSERT INTO users VALUES (#{args}) \ - ON CONFLICT (email) DO UPDATE SET updated = $1, subscriptions = $3", args: user_array) + Invidious::Database::Users.insert(user, update_on_conflict: true) Invidious::Database::SessionIDs.insert(sid, user.email, handle_conflicts: true) begin