diff --git a/include/mod_invites.hrl b/include/mod_invites.hrl index 8555808e9..104fd61ae 100644 --- a/include/mod_invites.hrl +++ b/include/mod_invites.hrl @@ -4,6 +4,11 @@ -define(NS_INVITE_INVITE, <<"urn:xmpp:invite#invite">>). -define(NS_INVITE_CREATE_ACCOUNT, <<"urn:xmpp:invite#create-account">>). +-define(OVERUSE_LIMIT, 1000). + +-define(SPEEDY_GOAT_LEVELS, 2). +-define(SPEEDY_GOAT_SECONDS, 300). + -record(invite_token, {token :: binary(), inviter :: {binary(), binary()}, %% A non-empty value if `invitee` indicates the invite has been used. diff --git a/rebar.config b/rebar.config index cd2212f99..a89443ce7 100644 --- a/rebar.config +++ b/rebar.config @@ -300,7 +300,8 @@ {copy, "test/ejabberd_SUITE_data/ca.pem", "conf/"}, {copy, "test/ejabberd_SUITE_data/cert.pem", "conf/"}]}]}]}, {translations, [{deps, [{ejabberd_po, ".*", {git, "https://github.com/processone/ejabberd-po", {branch, "main"}}}]}]}, - {test, [{erl_opts, [nowarn_export_all]}]}]}. + {test, [{erl_opts, [nowarn_export_all]}, + {deps, [meck]}]}]}. {alias, [{relive, [{shell, "--apps ejabberd \ --config rel/relive.config \ diff --git a/sql/lite.new.sql b/sql/lite.new.sql index d39590281..5cc9f25e2 100644 --- a/sql/lite.new.sql +++ b/sql/lite.new.sql @@ -503,3 +503,4 @@ CREATE TABLE invite_token ( PRIMARY KEY (token) ); CREATE INDEX i_invite_token_username_server_host ON invite_token(username, server_host); +CREATE INDEX i_invite_token_invitee ON invite_token(invitee); diff --git a/sql/lite.sql b/sql/lite.sql index a08d6242b..b602765ce 100644 --- a/sql/lite.sql +++ b/sql/lite.sql @@ -470,3 +470,4 @@ CREATE TABLE invite_token ( PRIMARY KEY (token) ); CREATE INDEX i_invite_token_username ON invite_token(username); +CREATE INDEX i_invite_token_invitee ON invite_token(invitee); diff --git a/sql/mysql.new.sql b/sql/mysql.new.sql index 407f8cd2b..d978288a6 100644 --- a/sql/mysql.new.sql +++ b/sql/mysql.new.sql @@ -522,3 +522,4 @@ CREATE TABLE invite_token ( ) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; CREATE INDEX i_invite_token_username USING BTREE ON invite_token(username(191), server_host(191)); +CREATE INDEX i_invite_token_invitee USING BTREE ON invite_token(invitee(191)); diff --git a/sql/mysql.sql b/sql/mysql.sql index 79779ef2c..4820d06ae 100644 --- a/sql/mysql.sql +++ b/sql/mysql.sql @@ -487,3 +487,4 @@ CREATE TABLE invite_token ( ) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; CREATE INDEX i_invite_token_username USING BTREE ON invite_token(username(191)); +CREATE INDEX i_invite_token_invitee USING BTREE ON invite_token(invitee(191)); diff --git a/sql/pg.new.sql b/sql/pg.new.sql index 999092919..f56eaff5f 100644 --- a/sql/pg.new.sql +++ b/sql/pg.new.sql @@ -676,3 +676,4 @@ CREATE TABLE invite_token ( PRIMARY KEY (token) ); CREATE INDEX i_invite_token_username_server_host ON invite_token USING btree (username, server_host); +CREATE INDEX i_invite_token_invitee ON invite_token USING btree (invitee); diff --git a/sql/pg.sql b/sql/pg.sql index 182e6636b..739ec3e5d 100644 --- a/sql/pg.sql +++ b/sql/pg.sql @@ -491,3 +491,4 @@ CREATE TABLE invite_token ( PRIMARY KEY (token) ); CREATE INDEX i_invite_token_username ON invite_token USING btree (username); +CREATE INDEX i_invite_token_invitee ON invite_token USING btree (invitee); diff --git a/src/mod_invites.erl b/src/mod_invites.erl index 86f67e676..93444b44f 100644 --- a/src/mod_invites.erl +++ b/src/mod_invites.erl @@ -46,15 +46,16 @@ -export([cleanup_expired/0, expire_tokens/2, generate_invite/1, generate_invite/2, list_invites/1]). %% helpers --export([create_account_allowed/2, get_invite/2, get_max_invites/2, is_create_allowed/2, - is_expired/1, is_reserved/3, is_token_valid/2, roster_add/2, send_presence/3, - set_invitee/3, set_invitee/5, token_uri/1, xdata_field/3]). +-export([create_account_allowed/2, get_invite/2, get_invites_tree_t/2, get_max_invites/2, + is_create_allowed/2, is_expired/1, is_reserved/3, is_token_valid/2, roster_add/2, + send_presence/3, set_invitee/3, set_invitee/5, token_uri/1, transaction/2, xdata_field/3]). %% ejabberd_http -export([process/2]). -ifdef(TEST). --export([create_roster_invite/2, create_account_invite/4, gen_invite/1, gen_invite/2, get_invites/2, is_token_valid/3]). +-export([create_roster_invite/2, create_account_invite/4, find_invites_tree_root_t/4, gen_invite/1, + gen_invite/2, get_invites/2, get_invites_tree_as_root_t/2, is_token_valid/3]). -endif. -include("logger.hrl"). @@ -66,12 +67,15 @@ -include("translate.hrl"). -type invite_token() :: #invite_token{}. +-export_type([invite_token/0]). -callback cleanup_expired(Host :: binary()) -> non_neg_integer(). -callback create_invite_t(Invite :: invite_token()) -> invite_token(). -callback expire_tokens(User :: binary(), Server :: binary()) -> non_neg_integer(). -callback get_invite(Host :: binary(), Token :: binary()) -> invite_token() | {error, not_found}. +-callback get_invite_by_invitee_t(Host :: binary(), InviteeJid :: binary()) -> + invite_token() | {error, not_found}. -callback get_invites_t(Host :: binary(), Inviter :: {User :: binary(), Host :: binary()}) -> [invite_token()]. -callback init(Host :: binary(), gen_mod:opts()) -> any(). @@ -807,7 +811,6 @@ create_invite(Type, Host, Inviter, AccountName) -> create_invite_t(Type, Host, Inviter, AccountName) -> try invite_token_t(Type, Host, Inviter, AccountName) of Invite -> - ?DEBUG("Creating invite: ~p", [Invite]), db_call(Host, create_invite_t, [Invite]) catch _:({error, _Reason} = Error) -> @@ -900,6 +903,89 @@ get_max_invites(User, Server) -> MaxInvites end. +check_overuse_t(roster_only, {User, Host}) -> + NumInvites = length(get_invites_t(Host, {User, Host})), + case NumInvites >= ?OVERUSE_LIMIT of + true -> + {error, num_invites_exceeded}; + false -> + ok + end; +check_overuse_t(_Type, {User, Host}) -> + NumInvites = length(get_invites_tree_t(Host, {User, Host})), + case NumInvites >= ?OVERUSE_LIMIT of + true -> + {error, num_invites_exceeded}; + false -> + ok + end. + +get_invites_tree_t(Host, Inviter) -> + Now = calendar:datetime_to_gregorian_seconds( + calendar:now_to_datetime( + erlang:timestamp())), + Root = find_invites_tree_root_t(Now, Host, Inviter, 0), + get_invites_tree_as_root_t(Host, Root). + +find_invites_tree_root_t(Now, Host, Invitee, Lvl) -> + case get_invite_by_invitee_t(Host, Invitee) of + #invite_token{inviter = Inviter, created_at = CreatedAt} -> + maybe_block_speedy_goat(Now, CreatedAt, Lvl), + find_invites_tree_root_t(Now, Host, Inviter, Lvl + 1); + {error, not_found} -> + Invitee + end. + +-spec get_invite_by_invitee_t(binary(), {binary(), binary()}) -> + invite_token() | {error, not_found}. +get_invite_by_invitee_t(Host, {User, Server}) -> + InviteeJid = + jid:encode( + jid:make(User, Server)), + db_call(Host, get_invite_by_invitee_t, [Host, InviteeJid]). + +maybe_block_speedy_goat(Now, CreatedAt, Lvl) when Lvl == ?SPEEDY_GOAT_LEVELS -> + Then = calendar:datetime_to_gregorian_seconds(CreatedAt), + if Now - Then < ?SPEEDY_GOAT_SECONDS -> + throw(speedy_goat); + true -> + ok + end; +maybe_block_speedy_goat(_, _, _) -> + ok. + +-spec get_invites_tree_as_root_t(binary(), {binary(), binary()}) -> [invite_token()]. +get_invites_tree_as_root_t(Host, Inviter) -> + Invites = get_invites_t(Host, Inviter), + get_invites_tree_as_root_t(Host, Inviter, Invites, []). + +get_invites_tree_as_root_t(_Host, _Inviter, [], Acc) -> + Acc; +get_invites_tree_as_root_t(Host, + Inviter, + [#invite_token{type = roster_only, account_name = <<>>} | Invites], + Acc) -> + get_invites_tree_as_root_t(Host, Inviter, Invites, Acc); +get_invites_tree_as_root_t(Host, + Inviter, + [#invite_token{invitee = <<>>} = Invite | Invites], + Acc) -> + get_invites_tree_as_root_t(Host, Inviter, Invites, [Invite | Acc]); +get_invites_tree_as_root_t(Host, + Inviter, + [#invite_token{invitee = InviteeJID} = Invite | Invites], + Acc) -> + case jid:decode(InviteeJID) of + #jid{luser = Invitee, lserver = Host} -> + get_invites_tree_as_root_t(Host, + Inviter, + Invites, + [Invite | Acc] + ++ get_invites_tree_as_root_t(Host, {Invitee, Host})); + _Nomatch -> + get_invites_tree_as_root_t(Host, Inviter, Invites, [Invite | Acc]) + end. + maybe_throw({error, _} = Error) -> throw(Error); maybe_throw(Good) -> @@ -907,6 +993,7 @@ maybe_throw(Good) -> invite_token_t(Type, Host, Inviter, AccountName0) -> maybe_throw(check_max_invites_t(Type, Inviter)), + maybe_throw(check_overuse_t(Type, Inviter)), Token = p1_rand:get_alphanum_string(?INVITE_TOKEN_LENGTH_DEFAULT), AccountName = maybe_throw(check_account_name(jid:nodeprep(AccountName0), Host)), set_token_expires(#invite_token{token = Token, diff --git a/src/mod_invites_mnesia.erl b/src/mod_invites_mnesia.erl index 088a4843a..07017e18d 100644 --- a/src/mod_invites_mnesia.erl +++ b/src/mod_invites_mnesia.erl @@ -27,11 +27,13 @@ -behaviour(mod_invites). --export([cleanup_expired/1, create_invite_t/1, expire_tokens/2, get_invite/2, get_invites_t/2, init/2, - is_reserved/3, is_token_valid/3, list_invites/1, remove_user/2, - set_invitee/5, transaction/2]). +-export([cleanup_expired/1, create_invite_t/1, expire_tokens/2, get_invite/2, get_invites_t/2, + get_invite_by_invitee_t/2, init/2, is_reserved/3, is_token_valid/3, list_invites/1, + remove_user/2, set_invitee/5, transaction/2]). -include("mod_invites.hrl"). +-include("logger.hrl"). +-include_lib("xmpp/include/xmpp.hrl"). %% @format-begin @@ -70,6 +72,14 @@ get_invite(_Host, Token) -> {error, not_found} end. +get_invite_by_invitee_t(_Host, InviteeJid) -> + case mnesia:index_read(invite_token, InviteeJid, #invite_token.invitee) of + [#invite_token{type = Type} = Invite] when Type /= roster_only -> + Invite; + _ -> + {error, not_found} + end. + get_invites_t(_Host, Inviter) -> mnesia:index_read(invite_token, Inviter, #invite_token.inviter). @@ -78,7 +88,7 @@ init(_Host, _Opts) -> invite_token, [{disc_copies, [node()]}, {attributes, record_info(fields, invite_token)}, - {index, [inviter]}]). + {index, [inviter, invitee]}]). is_reserved(_Host, Token, User) -> lists:filter(fun(T) -> diff --git a/src/mod_invites_register.erl b/src/mod_invites_register.erl index 395f2b9cd..78f775f35 100644 --- a/src/mod_invites_register.erl +++ b/src/mod_invites_register.erl @@ -196,7 +196,19 @@ create_account_allowed(#invite_token{type = roster_only} = Invite) -> #invite_token{inviter = {User, Host}} = Invite, case mod_invites:is_create_allowed(User, Host) of true -> - ok; + NumInvites = + length( + mod_invites:transaction( + Host, + fun() -> + mod_invites:get_invites_tree_t(Host, {User, Host}) + end)), + case NumInvites >= ?OVERUSE_LIMIT of + false -> + ok; + true -> + {error, not_allowed} + end; false -> {error, not_allowed} end; diff --git a/src/mod_invites_sql.erl b/src/mod_invites_sql.erl index 7e7ef65a3..bd1aed8e0 100644 --- a/src/mod_invites_sql.erl +++ b/src/mod_invites_sql.erl @@ -27,9 +27,9 @@ -behaviour(mod_invites). --export([cleanup_expired/1, create_invite_t/1, expire_tokens/2, get_invite/2, get_invites_t/2, init/2, - is_reserved/3, is_token_valid/3, list_invites/1, remove_user/2, - set_invitee/5, transaction/2]). +-export([cleanup_expired/1, create_invite_t/1, expire_tokens/2, get_invite/2, + get_invite_by_invitee_t/2, get_invites_t/2, init/2, is_reserved/3, is_token_valid/3, + list_invites/1, remove_user/2, set_invitee/5, transaction/2]). -export([sql_schemas/0]). @@ -46,7 +46,31 @@ init(Host, _Opts) -> ejabberd_sql_schema:update_schema(Host, ?MODULE, sql_schemas()). sql_schemas() -> - [#sql_schema{version = 1, + [#sql_schema{version = 2, + tables = + [#sql_table{name = <<"invite_token">>, + columns = + [#sql_column{name = <<"token">>, type = text}, + #sql_column{name = <<"username">>, type = text}, + #sql_column{name = <<"server_host">>, type = text}, + #sql_column{name = <<"invitee">>, + type = {text, 191}, + default = true}, + #sql_column{name = <<"created_at">>, + type = timestamp, + default = true}, + #sql_column{name = <<"expires">>, + type = timestamp, + default = true}, + #sql_column{name = <<"type">>, type = {char, 1}}, + #sql_column{name = <<"account_name">>, type = text}], + indices = + [#sql_index{columns = [<<"token">>], unique = true}, + #sql_index{columns = + [<<"username">>, <<"server_host">>]}, + #sql_index{columns = [<<"invitee">>]}]}], + update = [{create_index, <<"invite_token">>, [<<"invitee">>]}]}, + #sql_schema{version = 1, tables = [#sql_table{name = <<"invite_token">>, columns = @@ -126,6 +150,26 @@ get_invite(Host, Token) -> {error, not_found} end. +-spec get_invite_by_invitee_t(binary(), binary()) -> + mod_invites:invite_token() | {error, not_found}. +get_invite_by_invitee_t(Host, InviteeJid) -> + case ejabberd_sql:sql_query(Host, + ?SQL("SELECT @(token)s, @(username)s, @(invitee)s, @(type)s, " + "@(account_name)s, @(expires)t, @(created_at)t FROM " + "invite_token WHERE invitee = %(InviteeJid)s AND %(Host)H")) + of + {selected, [{Token, User, Invitee, Type, AccountName, Expires, CreatedAt}]} -> + #invite_token{token = Token, + inviter = {User, Host}, + invitee = Invitee, + type = dec_type(Type), + account_name = AccountName, + expires = Expires, + created_at = CreatedAt}; + {selected, []} -> + {error, not_found} + end. + get_invites_t(Host, {User, _Host}) -> {selected, Invites} = ejabberd_sql:sql_query_t(?SQL("SELECT @(token)s, @(invitee)s, @(type)s, @(account_name)s, " diff --git a/test/invites_tests.erl b/test/invites_tests.erl index 1b1437578..79d1af858 100644 --- a/test/invites_tests.erl +++ b/test/invites_tests.erl @@ -32,6 +32,8 @@ -include("mod_invites.hrl"). -include("mod_roster.hrl"). +-include_lib("eunit/include/eunit.hrl"). + %% killme -record(ejabberd_module, {module_host = {undefined, <<"">>} :: {atom(), binary()}, @@ -41,6 +43,80 @@ %% @format-begin +find_invites_tree_root_t_test_() -> + {setup, + fun() -> + meck:new(db, [non_strict]), + meck:expect(db, + get_invite_by_invitee_t, + fun (_, <<"4@host">>) -> + #invite_token{inviter = {<<"3">>, <<"host">>}}; + (_, <<"3@host">>) -> + #invite_token{inviter = {<<"2">>, <<"host">>}}; + (_, <<"2@host">>) -> + #invite_token{inviter = {<<"1">>, <<"host">>}}; + (_, _) -> + {error, not_found} + end), + meck:new(gen_mod, [passthrough]), + meck:expect(gen_mod, db_mod, 2, db), + meck:new(calendar, [unstick, passthrough]), + meck:expect(calendar, now_to_datetime, 1, then), + meck:expect(calendar, datetime_to_gregorian_seconds, fun(then) -> 1 end), + [db, gen_mod, calendar] + end, + fun meck:unload/1, + fun(_) -> + [%% lvl not reached + ?_assertMatch({<<"1">>, <<"host">>}, + mod_invites:find_invites_tree_root_t(2, host, {<<"3">>, <<"host">>}, 0)), + %% lvl reached + ?_assertThrow(speedy_goat, mod_invites:find_invites_tree_root_t(2, host, {<<"4">>, <<"host">>}, 0)), + %% lvl reached but later + ?_assertMatch({<<"1">>, <<"host">>}, + mod_invites:find_invites_tree_root_t(?SPEEDY_GOAT_SECONDS + 1, + host, + {<<"4">>, <<"host">>}, + 0)), + ?_assert(meck:validate(db))] + end}. + +get_invites_tree_as_root_t_test_() -> + {setup, + fun() -> + meck:new(db, [non_strict]), + meck:expect(db, + get_invites_t, + fun (_, {<<"1">>, _}) -> + [#invite_token{invitee = <<"2@host">>, type = account_only}, + #invite_token{invitee = <<"rosterinvite@forcecrash">>}]; + (_, {<<"2">>, _}) -> + [#invite_token{invitee = <<"3@host">>, type = account_only}, + #invite_token{invitee = <<"4@host">>, type = account_only}]; + (_, {<<"3">>, _}) -> + [#invite_token{invitee = <<"5@host">>, type = account_subscription}, + #invite_token{invitee = <<"6@host">>, account_name = <<"6">>}, + #invite_token{type = account_only}]; + (_, {_, <<"host">>}) -> + [] + end), + meck:new(gen_mod, [passthrough]), + meck:expect(gen_mod, db_mod, 2, db), + meck:expect(jid, + decode, + fun(Str) -> + [LUser, LServer] = + [list_to_binary(T) || T <- string:tokens(binary_to_list(Str), "@")], + #jid{luser = LUser, lserver = LServer} + end), + [db, gen_mod, jid] + end, + fun meck:unload/1, + fun(_) -> + [?_assertMatch(6, length(mod_invites:get_invites_tree_as_root_t(<<"host">>, {<<"1">>, <<"host">>})))] + end}. + + %%%=================================================================== %%% API %%%===================================================================