Add cancellation token support for beatmap difficulty calculation.

This commit is contained in:
Tollii
2021-11-06 00:19:48 +01:00
parent f0caa10066
commit eb7d04bc77
4 changed files with 18 additions and 13 deletions

View File

@ -147,13 +147,13 @@ namespace osu.Game.Beatmaps
if (CheckExists(lookup, out var existing)) if (CheckExists(lookup, out var existing))
return existing; return existing;
return computeDifficulty(lookup); return computeDifficulty(lookup, token);
}, token, TaskCreationOptions.HideScheduler | TaskCreationOptions.RunContinuationsAsynchronously, updateScheduler); }, token, TaskCreationOptions.HideScheduler | TaskCreationOptions.RunContinuationsAsynchronously, updateScheduler);
} }
public Task<List<TimedDifficultyAttributes>> GetTimedDifficultyAttributesAsync(WorkingBeatmap beatmap, Ruleset ruleset, Mod[] mods, CancellationToken token = default) public Task<List<TimedDifficultyAttributes>> GetTimedDifficultyAttributesAsync(WorkingBeatmap beatmap, Ruleset ruleset, Mod[] mods, CancellationToken token = default)
{ {
return Task.Factory.StartNew(() => ruleset.CreateDifficultyCalculator(beatmap).CalculateTimed(mods), return Task.Factory.StartNew(() => ruleset.CreateDifficultyCalculator(beatmap).CalculateTimed(mods, token),
token, token,
TaskCreationOptions.HideScheduler | TaskCreationOptions.RunContinuationsAsynchronously, TaskCreationOptions.HideScheduler | TaskCreationOptions.RunContinuationsAsynchronously,
updateScheduler); updateScheduler);
@ -270,8 +270,9 @@ namespace osu.Game.Beatmaps
/// Computes the difficulty defined by a <see cref="DifficultyCacheLookup"/> key, and stores it to the timed cache. /// Computes the difficulty defined by a <see cref="DifficultyCacheLookup"/> key, and stores it to the timed cache.
/// </summary> /// </summary>
/// <param name="key">The <see cref="DifficultyCacheLookup"/> that defines the computation parameters.</param> /// <param name="key">The <see cref="DifficultyCacheLookup"/> that defines the computation parameters.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The <see cref="StarDifficulty"/>.</returns> /// <returns>The <see cref="StarDifficulty"/>.</returns>
private StarDifficulty computeDifficulty(in DifficultyCacheLookup key) private StarDifficulty computeDifficulty(in DifficultyCacheLookup key, CancellationToken cancellationToken = default)
{ {
// In the case that the user hasn't given us a ruleset, use the beatmap's default ruleset. // In the case that the user hasn't given us a ruleset, use the beatmap's default ruleset.
var beatmapInfo = key.BeatmapInfo; var beatmapInfo = key.BeatmapInfo;
@ -283,7 +284,7 @@ namespace osu.Game.Beatmaps
Debug.Assert(ruleset != null); Debug.Assert(ruleset != null);
var calculator = ruleset.CreateDifficultyCalculator(beatmapManager.GetWorkingBeatmap(key.BeatmapInfo)); var calculator = ruleset.CreateDifficultyCalculator(beatmapManager.GetWorkingBeatmap(key.BeatmapInfo));
var attributes = calculator.Calculate(key.OrderedMods); var attributes = calculator.Calculate(key.OrderedMods, cancellationToken);
return new StarDifficulty(attributes); return new StarDifficulty(attributes);
} }

View File

@ -408,7 +408,7 @@ namespace osu.Game.Beatmaps
beatmap.BeatmapInfo.Ruleset = ruleset; beatmap.BeatmapInfo.Ruleset = ruleset;
// TODO: this should be done in a better place once we actually need to dynamically update it. // TODO: this should be done in a better place once we actually need to dynamically update it.
beatmap.BeatmapInfo.StarDifficulty = ruleset?.CreateInstance().CreateDifficultyCalculator(new DummyConversionBeatmap(beatmap)).Calculate().StarRating ?? 0; beatmap.BeatmapInfo.StarDifficulty = ruleset?.CreateInstance().CreateDifficultyCalculator(new DummyConversionBeatmap(beatmap)).Calculate(null).StarRating ?? 0;
beatmap.BeatmapInfo.Length = calculateLength(beatmap); beatmap.BeatmapInfo.Length = calculateLength(beatmap);
beatmap.BeatmapInfo.BPM = 60000 / beatmap.GetMostCommonBeatLength(); beatmap.BeatmapInfo.BPM = 60000 / beatmap.GetMostCommonBeatLength();

View File

@ -4,6 +4,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading;
using osu.Framework.Audio.Track; using osu.Framework.Audio.Track;
using osu.Framework.Extensions.IEnumerableExtensions; using osu.Framework.Extensions.IEnumerableExtensions;
using osu.Game.Beatmaps; using osu.Game.Beatmaps;
@ -39,10 +40,11 @@ namespace osu.Game.Rulesets.Difficulty
/// Calculates the difficulty of the beatmap using a specific mod combination. /// Calculates the difficulty of the beatmap using a specific mod combination.
/// </summary> /// </summary>
/// <param name="mods">The mods that should be applied to the beatmap.</param> /// <param name="mods">The mods that should be applied to the beatmap.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A structure describing the difficulty of the beatmap.</returns> /// <returns>A structure describing the difficulty of the beatmap.</returns>
public DifficultyAttributes Calculate(params Mod[] mods) public DifficultyAttributes Calculate(IEnumerable<Mod> mods, CancellationToken cancellationToken = default)
{ {
preProcess(mods); preProcess(mods, cancellationToken);
var skills = CreateSkills(Beatmap, playableMods, clockRate); var skills = CreateSkills(Beatmap, playableMods, clockRate);
@ -62,10 +64,11 @@ namespace osu.Game.Rulesets.Difficulty
/// Calculates the difficulty of the beatmap and returns a set of <see cref="TimedDifficultyAttributes"/> representing the difficulty at every relevant time value in the beatmap. /// Calculates the difficulty of the beatmap and returns a set of <see cref="TimedDifficultyAttributes"/> representing the difficulty at every relevant time value in the beatmap.
/// </summary> /// </summary>
/// <param name="mods">The mods that should be applied to the beatmap.</param> /// <param name="mods">The mods that should be applied to the beatmap.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The set of <see cref="TimedDifficultyAttributes"/>.</returns> /// <returns>The set of <see cref="TimedDifficultyAttributes"/>.</returns>
public List<TimedDifficultyAttributes> CalculateTimed(params Mod[] mods) public List<TimedDifficultyAttributes> CalculateTimed(IEnumerable<Mod> mods, CancellationToken cancellationToken = default)
{ {
preProcess(mods); preProcess(mods, cancellationToken);
var attribs = new List<TimedDifficultyAttributes>(); var attribs = new List<TimedDifficultyAttributes>();
@ -99,7 +102,7 @@ namespace osu.Game.Rulesets.Difficulty
if (combination is MultiMod multi) if (combination is MultiMod multi)
yield return Calculate(multi.Mods); yield return Calculate(multi.Mods);
else else
yield return Calculate(combination); yield return Calculate(new[] { combination });
} }
} }
@ -112,11 +115,12 @@ namespace osu.Game.Rulesets.Difficulty
/// Performs required tasks before every calculation. /// Performs required tasks before every calculation.
/// </summary> /// </summary>
/// <param name="mods">The original list of <see cref="Mod"/>s.</param> /// <param name="mods">The original list of <see cref="Mod"/>s.</param>
private void preProcess(Mod[] mods) /// <param name="cancellationToken">The cancellation cancellationToken.</param>
private void preProcess(IEnumerable<Mod> mods, CancellationToken cancellationToken = default)
{ {
playableMods = mods.Select(m => m.DeepClone()).ToArray(); playableMods = mods.Select(m => m.DeepClone()).ToArray();
Beatmap = beatmap.GetPlayableBeatmap(ruleset.RulesetInfo, playableMods); Beatmap = beatmap.GetPlayableBeatmap(ruleset.RulesetInfo, playableMods, cancellationToken: cancellationToken);
var track = new TrackVirtual(10000); var track = new TrackVirtual(10000);
playableMods.OfType<IApplicableToTrack>().ForEach(m => m.ApplyToTrack(track)); playableMods.OfType<IApplicableToTrack>().ForEach(m => m.ApplyToTrack(track));

View File

@ -284,7 +284,7 @@ namespace osu.Game.Stores
decoded.BeatmapInfo.Ruleset = rulesetInstance.RulesetInfo; decoded.BeatmapInfo.Ruleset = rulesetInstance.RulesetInfo;
// TODO: this should be done in a better place once we actually need to dynamically update it. // TODO: this should be done in a better place once we actually need to dynamically update it.
beatmap.StarRating = rulesetInstance.CreateDifficultyCalculator(new DummyConversionBeatmap(decoded)).Calculate().StarRating; beatmap.StarRating = rulesetInstance.CreateDifficultyCalculator(new DummyConversionBeatmap(decoded)).Calculate(null).StarRating;
beatmap.Length = calculateLength(decoded); beatmap.Length = calculateLength(decoded);
beatmap.BPM = 60000 / decoded.GetMostCommonBeatLength(); beatmap.BPM = 60000 / decoded.GetMostCommonBeatLength();
} }