Fix jobs being added to batch after they might already execute (#35496)

This commit is contained in:
Eugen Rochko 2025-07-28 10:20:12 +02:00 committed by GitHub
parent a57a9505d4
commit 018e5e303f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 52 additions and 12 deletions

View File

@ -19,17 +19,22 @@ class WorkerBatch
redis.hset(key, { 'async_refresh_key' => async_refresh_key, 'threshold' => threshold }) redis.hset(key, { 'async_refresh_key' => async_refresh_key, 'threshold' => threshold })
end end
def within
raise NoBlockGivenError unless block_given?
begin
Thread.current[:batch] = self
yield
ensure
Thread.current[:batch] = nil
end
end
# Add jobs to the batch. Usually when the batch is created. # Add jobs to the batch. Usually when the batch is created.
# @param [Array<String>] jids # @param [Array<String>] jids
def add_jobs(jids) def add_jobs(jids)
if jids.blank? if jids.blank?
async_refresh_key = redis.hget(key, 'async_refresh_key') finish!
if async_refresh_key.present?
async_refresh = AsyncRefresh.new(async_refresh_key)
async_refresh.finish!
end
return return
end end
@ -55,8 +60,23 @@ class WorkerBatch
if async_refresh_key.present? if async_refresh_key.present?
async_refresh = AsyncRefresh.new(async_refresh_key) async_refresh = AsyncRefresh.new(async_refresh_key)
async_refresh.increment_result_count(by: 1) async_refresh.increment_result_count(by: 1)
async_refresh.finish! if pending.zero? || processed >= threshold.to_f * (processed + pending)
end end
if pending.zero? || processed >= (threshold || 1.0).to_f * (processed + pending)
async_refresh&.finish!
cleanup
end
end
def finish!
async_refresh_key = redis.hget(key, 'async_refresh_key')
if async_refresh_key.present?
async_refresh = AsyncRefresh.new(async_refresh_key)
async_refresh.finish!
end
cleanup
end end
# Get pending jobs. # Get pending jobs.
@ -76,4 +96,8 @@ class WorkerBatch
def key(suffix = nil) def key(suffix = nil)
"worker_batch:#{@id}#{":#{suffix}" if suffix}" "worker_batch:#{@id}#{":#{suffix}" if suffix}"
end end
def cleanup
redis.del(key, key('jobs'))
end
end end

View File

@ -17,7 +17,12 @@ class ActivityPub::FetchRepliesService < BaseService
batch = WorkerBatch.new batch = WorkerBatch.new
batch.connect(async_refresh_key) if async_refresh_key.present? batch.connect(async_refresh_key) if async_refresh_key.present?
batch.add_jobs(FetchReplyWorker.push_bulk(@items) { |reply_uri| [reply_uri, { 'request_id' => request_id, 'batch_id' => batch.id }] }) batch.finish! if @items.empty?
batch.within do
FetchReplyWorker.push_bulk(@items) do |reply_uri|
[reply_uri, { 'request_id' => request_id, 'batch_id' => batch.id }]
end
end
[@items, n_pages] [@items, n_pages]
end end

View File

@ -1,6 +1,7 @@
# frozen_string_literal: true # frozen_string_literal: true
require_relative '../../lib/mastodon/sidekiq_middleware' require_relative '../../lib/mastodon/sidekiq_middleware'
require_relative '../../lib/mastodon/worker_batch_middleware'
Sidekiq.configure_server do |config| Sidekiq.configure_server do |config|
config.redis = REDIS_CONFIGURATION.sidekiq config.redis = REDIS_CONFIGURATION.sidekiq
@ -72,14 +73,12 @@ Sidekiq.configure_server do |config|
config.server_middleware do |chain| config.server_middleware do |chain|
chain.add Mastodon::SidekiqMiddleware chain.add Mastodon::SidekiqMiddleware
end
config.server_middleware do |chain|
chain.add SidekiqUniqueJobs::Middleware::Server chain.add SidekiqUniqueJobs::Middleware::Server
end end
config.client_middleware do |chain| config.client_middleware do |chain|
chain.add SidekiqUniqueJobs::Middleware::Client chain.add SidekiqUniqueJobs::Middleware::Client
chain.add Mastodon::WorkerBatchMiddleware
end end
config.on(:startup) do config.on(:startup) do
@ -105,6 +104,7 @@ Sidekiq.configure_client do |config|
config.client_middleware do |chain| config.client_middleware do |chain|
chain.add SidekiqUniqueJobs::Middleware::Client chain.add SidekiqUniqueJobs::Middleware::Client
chain.add Mastodon::WorkerBatchMiddleware
end end
config.logger.level = Logger.const_get(ENV.fetch('RAILS_LOG_LEVEL', 'info').upcase.to_s) config.logger.level = Logger.const_get(ENV.fetch('RAILS_LOG_LEVEL', 'info').upcase.to_s)

View File

@ -0,0 +1,11 @@
# frozen_string_literal: true
class Mastodon::WorkerBatchMiddleware
def call(_worker, msg, _queue, _redis_pool = nil)
if (batch = Thread.current[:batch])
batch.add_jobs([msg['jid']])
end
yield
end
end