diff --git a/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs b/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs index f020c2a805..4c9260f640 100644 --- a/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs +++ b/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs @@ -21,14 +21,14 @@ namespace osu.Game.Tests.Beatmaps.IO public class ImportBeatmapTest { [Test] - public void TestImportWhenClosed() + public async Task TestImportWhenClosed() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportWhenClosed")) { try { - LoadOszIntoOsu(loadOsu(host)); + await LoadOszIntoOsu(loadOsu(host)); } finally { @@ -38,7 +38,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportThenDelete() + public async Task TestImportThenDelete() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportThenDelete")) @@ -47,7 +47,7 @@ namespace osu.Game.Tests.Beatmaps.IO { var osu = loadOsu(host); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); deleteBeatmapSet(imported, osu); } @@ -59,7 +59,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportThenImport() + public async Task TestImportThenImport() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportThenImport")) @@ -68,8 +68,8 @@ namespace osu.Game.Tests.Beatmaps.IO { var osu = loadOsu(host); - var imported = LoadOszIntoOsu(osu); - var importedSecondTime = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); + var importedSecondTime = await LoadOszIntoOsu(osu); // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash. Assert.IsTrue(imported.ID == importedSecondTime.ID); @@ -88,7 +88,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestRollbackOnFailure() + public async Task TestRollbackOnFailure() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestRollbackOnFailure")) @@ -104,7 +104,7 @@ namespace osu.Game.Tests.Beatmaps.IO manager.ItemAdded += (_, __) => fireCount++; manager.ItemRemoved += _ => fireCount++; - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); Assert.AreEqual(0, fireCount -= 1); @@ -132,7 +132,7 @@ namespace osu.Game.Tests.Beatmaps.IO Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count); // this will trigger purging of the existing beatmap (online set id match) but should rollback due to broken osu. - manager.Import(breakTemp); + await manager.Import(breakTemp); // no events should be fired in the case of a rollback. Assert.AreEqual(0, fireCount); @@ -149,7 +149,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportThenImportDifferentHash() + public async Task TestImportThenImportDifferentHash() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportThenImportDifferentHash")) @@ -159,12 +159,12 @@ namespace osu.Game.Tests.Beatmaps.IO var osu = loadOsu(host); var manager = osu.Dependencies.Get(); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); imported.Hash += "-changed"; manager.Update(imported); - var importedSecondTime = LoadOszIntoOsu(osu); + var importedSecondTime = await LoadOszIntoOsu(osu); Assert.IsTrue(imported.ID != importedSecondTime.ID); Assert.IsTrue(imported.Beatmaps.First().ID < importedSecondTime.Beatmaps.First().ID); @@ -181,7 +181,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportThenDeleteThenImport() + public async Task TestImportThenDeleteThenImport() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportThenDeleteThenImport")) @@ -190,11 +190,11 @@ namespace osu.Game.Tests.Beatmaps.IO { var osu = loadOsu(host); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); deleteBeatmapSet(imported, osu); - var importedSecondTime = LoadOszIntoOsu(osu); + var importedSecondTime = await LoadOszIntoOsu(osu); // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash. Assert.IsTrue(imported.ID == importedSecondTime.ID); @@ -209,7 +209,7 @@ namespace osu.Game.Tests.Beatmaps.IO [TestCase(true)] [TestCase(false)] - public void TestImportThenDeleteThenImportWithOnlineIDMismatch(bool set) + public async Task TestImportThenDeleteThenImportWithOnlineIDMismatch(bool set) { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost($"TestImportThenDeleteThenImport-{set}")) @@ -218,7 +218,7 @@ namespace osu.Game.Tests.Beatmaps.IO { var osu = loadOsu(host); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); if (set) imported.OnlineBeatmapSetID = 1234; @@ -229,7 +229,7 @@ namespace osu.Game.Tests.Beatmaps.IO deleteBeatmapSet(imported, osu); - var importedSecondTime = LoadOszIntoOsu(osu); + var importedSecondTime = await LoadOszIntoOsu(osu); // check the newly "imported" beatmap has been reimported due to mismatch (even though hashes matched) Assert.IsTrue(imported.ID != importedSecondTime.ID); @@ -243,7 +243,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportWithDuplicateBeatmapIDs() + public async Task TestImportWithDuplicateBeatmapIDs() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportWithDuplicateBeatmapID")) @@ -284,7 +284,7 @@ namespace osu.Game.Tests.Beatmaps.IO var manager = osu.Dependencies.Get(); - var imported = manager.Import(toImport); + var imported = await manager.Import(toImport); Assert.NotNull(imported); Assert.AreEqual(null, imported.Beatmaps[0].OnlineBeatmapID); @@ -330,7 +330,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportWhenFileOpen() + public async Task TestImportWhenFileOpen() { using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportWhenFileOpen")) { @@ -339,7 +339,7 @@ namespace osu.Game.Tests.Beatmaps.IO var osu = loadOsu(host); var temp = TestResources.GetTestBeatmapForImport(); using (File.OpenRead(temp)) - osu.Dependencies.Get().Import(temp); + await osu.Dependencies.Get().Import(temp); ensureLoaded(osu); File.Delete(temp); Assert.IsFalse(File.Exists(temp), "We likely held a read lock on the file when we shouldn't"); @@ -351,13 +351,13 @@ namespace osu.Game.Tests.Beatmaps.IO } } - public static BeatmapSetInfo LoadOszIntoOsu(OsuGameBase osu, string path = null) + public static async Task LoadOszIntoOsu(OsuGameBase osu, string path = null) { var temp = path ?? TestResources.GetTestBeatmapForImport(); var manager = osu.Dependencies.Get(); - manager.Import(temp); + await manager.Import(temp); var imported = manager.GetAllUsableBeatmapSets(); diff --git a/osu.Game.Tests/Scores/IO/ImportScoreTest.cs b/osu.Game.Tests/Scores/IO/ImportScoreTest.cs index e39f18c3cd..4babb07213 100644 --- a/osu.Game.Tests/Scores/IO/ImportScoreTest.cs +++ b/osu.Game.Tests/Scores/IO/ImportScoreTest.cs @@ -23,13 +23,13 @@ namespace osu.Game.Tests.Scores.IO public class ImportScoreTest { [Test] - public void TestBasicImport() + public async Task TestBasicImport() { using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestBasicImport")) { try { - var osu = loadOsu(host); + var osu = await loadOsu(host); var toImport = new ScoreInfo { @@ -43,7 +43,7 @@ namespace osu.Game.Tests.Scores.IO OnlineScoreID = 12345, }; - var imported = loadIntoOsu(osu, toImport); + var imported = await loadIntoOsu(osu, toImport); Assert.AreEqual(toImport.Rank, imported.Rank); Assert.AreEqual(toImport.TotalScore, imported.TotalScore); @@ -62,20 +62,20 @@ namespace osu.Game.Tests.Scores.IO } [Test] - public void TestImportMods() + public async Task TestImportMods() { using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportMods")) { try { - var osu = loadOsu(host); + var osu = await loadOsu(host); var toImport = new ScoreInfo { Mods = new Mod[] { new OsuModHardRock(), new OsuModDoubleTime() }, }; - var imported = loadIntoOsu(osu, toImport); + var imported = await loadIntoOsu(osu, toImport); Assert.IsTrue(imported.Mods.Any(m => m is OsuModHardRock)); Assert.IsTrue(imported.Mods.Any(m => m is OsuModDoubleTime)); @@ -88,13 +88,13 @@ namespace osu.Game.Tests.Scores.IO } [Test] - public void TestImportStatistics() + public async Task TestImportStatistics() { using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportStatistics")) { try { - var osu = loadOsu(host); + var osu = await loadOsu(host); var toImport = new ScoreInfo { @@ -105,7 +105,7 @@ namespace osu.Game.Tests.Scores.IO } }; - var imported = loadIntoOsu(osu, toImport); + var imported = await loadIntoOsu(osu, toImport); Assert.AreEqual(toImport.Statistics[HitResult.Perfect], imported.Statistics[HitResult.Perfect]); Assert.AreEqual(toImport.Statistics[HitResult.Miss], imported.Statistics[HitResult.Miss]); @@ -117,7 +117,7 @@ namespace osu.Game.Tests.Scores.IO } } - private ScoreInfo loadIntoOsu(OsuGameBase osu, ScoreInfo score) + private async Task loadIntoOsu(OsuGameBase osu, ScoreInfo score) { var beatmapManager = osu.Dependencies.Get(); @@ -125,20 +125,24 @@ namespace osu.Game.Tests.Scores.IO score.Ruleset = new OsuRuleset().RulesetInfo; var scoreManager = osu.Dependencies.Get(); - scoreManager.Import(score); + await scoreManager.Import(score); return scoreManager.GetAllUsableScores().First(); } - private OsuGameBase loadOsu(GameHost host) + private async Task loadOsu(GameHost host) { var osu = new OsuGameBase(); + +#pragma warning disable 4014 Task.Run(() => host.Run(osu)); +#pragma warning restore 4014 + waitForOrAssert(() => osu.IsLoaded, @"osu! failed to start in a reasonable amount of time"); var beatmapFile = TestResources.GetTestBeatmapForImport(); var beatmapManager = osu.Dependencies.Get(); - beatmapManager.Import(beatmapFile); + await beatmapManager.Import(beatmapFile); return osu; } diff --git a/osu.Game.Tests/Visual/Background/TestSceneBackgroundScreenBeatmap.cs b/osu.Game.Tests/Visual/Background/TestSceneBackgroundScreenBeatmap.cs index 7104a420a3..8b941e4633 100644 --- a/osu.Game.Tests/Visual/Background/TestSceneBackgroundScreenBeatmap.cs +++ b/osu.Game.Tests/Visual/Background/TestSceneBackgroundScreenBeatmap.cs @@ -72,7 +72,7 @@ namespace osu.Game.Tests.Visual.Background Dependencies.Cache(manager = new BeatmapManager(LocalStorage, factory, rulesets, null, audio, host, Beatmap.Default)); Dependencies.Cache(new OsuConfigManager(LocalStorage)); - manager.Import(TestResources.GetTestBeatmapForImport()); + manager.Import(TestResources.GetTestBeatmapForImport()).Wait(); Beatmap.SetDefault(); } diff --git a/osu.Game.Tests/Visual/UserInterface/TestSceneUpdateableBeatmapBackgroundSprite.cs b/osu.Game.Tests/Visual/UserInterface/TestSceneUpdateableBeatmapBackgroundSprite.cs index f59458ef8d..c361598354 100644 --- a/osu.Game.Tests/Visual/UserInterface/TestSceneUpdateableBeatmapBackgroundSprite.cs +++ b/osu.Game.Tests/Visual/UserInterface/TestSceneUpdateableBeatmapBackgroundSprite.cs @@ -32,7 +32,7 @@ namespace osu.Game.Tests.Visual.UserInterface this.api = api; this.rulesets = rulesets; - testBeatmap = ImportBeatmapTest.LoadOszIntoOsu(osu); + testBeatmap = ImportBeatmapTest.LoadOszIntoOsu(osu).Result; } [Test] diff --git a/osu.Game.Tests/osu.Game.Tests.csproj b/osu.Game.Tests/osu.Game.Tests.csproj index 11d70ee7be..403ae16885 100644 --- a/osu.Game.Tests/osu.Game.Tests.csproj +++ b/osu.Game.Tests/osu.Game.Tests.csproj @@ -18,4 +18,9 @@ + + + ..\..\..\..\..\usr\local\share\dotnet\sdk\NuGetFallbackFolder\microsoft.win32.registry\4.5.0\ref\netstandard2.0\Microsoft.Win32.Registry.dll + + \ No newline at end of file diff --git a/osu.Game/Beatmaps/BeatmapManager.cs b/osu.Game/Beatmaps/BeatmapManager.cs index b6fe7f88fa..f5ef52a93b 100644 --- a/osu.Game/Beatmaps/BeatmapManager.cs +++ b/osu.Game/Beatmaps/BeatmapManager.cs @@ -2,10 +2,12 @@ // See the LICENCE file in the repository root for full licence text. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; using System.Linq.Expressions; +using System.Threading; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; using osu.Framework.Audio; @@ -14,6 +16,7 @@ using osu.Framework.Extensions; using osu.Framework.Graphics.Textures; using osu.Framework.Logging; using osu.Framework.Platform; +using osu.Framework.Threading; using osu.Game.Beatmaps.Formats; using osu.Game.Database; using osu.Game.IO.Archives; @@ -72,6 +75,8 @@ namespace osu.Game.Beatmaps private readonly List currentDownloads = new List(); + private readonly BeatmapUpdateQueue updateQueue; + public BeatmapManager(Storage storage, IDatabaseContextFactory contextFactory, RulesetStore rulesets, IAPIProvider api, AudioManager audioManager, GameHost host = null, WorkingBeatmap defaultBeatmap = null) : base(storage, contextFactory, new BeatmapStore(contextFactory), host) @@ -86,9 +91,11 @@ namespace osu.Game.Beatmaps beatmaps = (BeatmapStore)ModelStore; beatmaps.BeatmapHidden += b => BeatmapHidden?.Invoke(b); beatmaps.BeatmapRestored += b => BeatmapRestored?.Invoke(b); + + updateQueue = new BeatmapUpdateQueue(api); } - protected override void Populate(BeatmapSetInfo beatmapSet, ArchiveReader archive) + protected override async Task Populate(BeatmapSetInfo beatmapSet, ArchiveReader archive, CancellationToken cancellationToken = default) { if (archive != null) beatmapSet.Beatmaps = createBeatmapDifficulties(archive); @@ -104,8 +111,7 @@ namespace osu.Game.Beatmaps validateOnlineIds(beatmapSet); - foreach (BeatmapInfo b in beatmapSet.Beatmaps) - fetchAndPopulateOnlineValues(b); + await Task.WhenAll(beatmapSet.Beatmaps.Select(b => updateQueue.Enqueue(new UpdateItem(b, cancellationToken)).Task).ToArray()); } protected override void PreImport(BeatmapSetInfo beatmapSet) @@ -181,10 +187,10 @@ namespace osu.Game.Beatmaps request.Success += filename => { - Task.Factory.StartNew(() => + Task.Factory.StartNew(async () => { // This gets scheduled back to the update thread, but we want the import to run in the background. - Import(downloadNotification, filename); + await Import(downloadNotification, filename); currentDownloads.Remove(request); }, TaskCreationOptions.LongRunning); }; @@ -381,47 +387,6 @@ namespace osu.Game.Beatmaps return beatmapInfos; } - /// - /// Query the API to populate missing values like OnlineBeatmapID / OnlineBeatmapSetID or (Rank-)Status. - /// - /// The beatmap to populate. - /// Whether to re-query if the provided beatmap already has populated values. - /// True if population was successful. - private bool fetchAndPopulateOnlineValues(BeatmapInfo beatmap, bool force = false) - { - if (api?.State != APIState.Online) - return false; - - if (!force && beatmap.OnlineBeatmapID != null && beatmap.BeatmapSet.OnlineBeatmapSetID != null - && beatmap.Status != BeatmapSetOnlineStatus.None && beatmap.BeatmapSet.Status != BeatmapSetOnlineStatus.None) - return true; - - Logger.Log("Attempting online lookup for the missing values...", LoggingTarget.Database); - - try - { - var req = new GetBeatmapRequest(beatmap); - - req.Perform(api); - - var res = req.Result; - - Logger.Log($"Successfully mapped to {res.OnlineBeatmapSetID} / {res.OnlineBeatmapID}.", LoggingTarget.Database); - - beatmap.Status = res.Status; - beatmap.BeatmapSet.Status = res.BeatmapSet.Status; - beatmap.BeatmapSet.OnlineBeatmapSetID = res.OnlineBeatmapSetID; - beatmap.OnlineBeatmapID = res.OnlineBeatmapID; - - return true; - } - catch (Exception e) - { - Logger.Log($"Failed ({e})", LoggingTarget.Database); - return false; - } - } - /// /// A dummy WorkingBeatmap for the purpose of retrieving a beatmap for star difficulty calculation. /// @@ -455,5 +420,111 @@ namespace osu.Game.Beatmaps public override bool IsImportant => false; } } + + private class BeatmapUpdateQueue + { + private readonly IAPIProvider api; + private readonly Queue queue = new Queue(); + + private int activeThreads; + + public BeatmapUpdateQueue(IAPIProvider api) + { + this.api = api; + } + + public UpdateItem Enqueue(UpdateItem item) + { + lock (queue) + { + queue.Enqueue(item); + + if (activeThreads >= 16) + return item; + + new Thread(runWork) { IsBackground = true }.Start(); + activeThreads++; + } + + return item; + } + + private void runWork() + { + while (true) + { + UpdateItem toProcess; + + lock (queue) + { + if (queue.Count == 0) + break; + + toProcess = queue.Dequeue(); + } + + toProcess.PerformUpdate(api); + } + + lock (queue) + activeThreads--; + } + } + + private class UpdateItem + { + public Task Task => tcs.Task; + + private readonly BeatmapInfo beatmap; + private readonly CancellationToken cancellationToken; + + private readonly TaskCompletionSource tcs = new TaskCompletionSource(); + + public UpdateItem(BeatmapInfo beatmap, CancellationToken cancellationToken) + { + this.beatmap = beatmap; + this.cancellationToken = cancellationToken; + } + + public void PerformUpdate(IAPIProvider api) + { + if (cancellationToken.IsCancellationRequested) + { + tcs.SetCanceled(); + return; + } + + if (api?.State != APIState.Online) + { + tcs.SetResult(false); + return; + } + + Logger.Log("Attempting online lookup for the missing values...", LoggingTarget.Database); + + var req = new GetBeatmapRequest(beatmap); + + req.Success += res => + { + Logger.Log($"Successfully mapped to {res.OnlineBeatmapSetID} / {res.OnlineBeatmapID}.", LoggingTarget.Database); + + beatmap.Status = res.Status; + beatmap.BeatmapSet.Status = res.BeatmapSet.Status; + beatmap.BeatmapSet.OnlineBeatmapSetID = res.OnlineBeatmapSetID; + beatmap.OnlineBeatmapID = res.OnlineBeatmapID; + + tcs.SetResult(true); + }; + + req.Failure += e => + { + Logger.Log($"Failed ({e})", LoggingTarget.Database); + + tcs.SetResult(false); + }; + + req.Perform(api); + } + } } } diff --git a/osu.Game/Database/ArchiveModelManager.cs b/osu.Game/Database/ArchiveModelManager.cs index 54dbae9ddc..afc614772e 100644 --- a/osu.Game/Database/ArchiveModelManager.cs +++ b/osu.Game/Database/ArchiveModelManager.cs @@ -5,14 +5,17 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore; using osu.Framework; using osu.Framework.Extensions; +using osu.Framework.Extensions.TypeExtensions; using osu.Framework.IO.File; using osu.Framework.Logging; using osu.Framework.Platform; +using osu.Framework.Threading; using osu.Game.IO; using osu.Game.IO.Archives; using osu.Game.IPC; @@ -109,8 +112,11 @@ namespace osu.Game.Database a.Invoke(); } + private readonly ThreadedTaskScheduler importScheduler; + protected ArchiveModelManager(Storage storage, IDatabaseContextFactory contextFactory, MutableDatabaseBackedStoreWithFileIncludes modelStore, IIpcHost importHost = null) { + importScheduler = new ThreadedTaskScheduler(16, $"{GetType().ReadableName()}.Import"); ContextFactory = contextFactory; ModelStore = modelStore; @@ -130,92 +136,84 @@ namespace osu.Game.Database /// This will post notifications tracking progress. /// /// One or more archive locations on disk. - public void Import(params string[] paths) + public async Task Import(params string[] paths) { var notification = new ProgressNotification { State = ProgressNotificationState.Active }; PostNotification?.Invoke(notification); - Import(notification, paths); + + await Import(notification, paths); } - protected void Import(ProgressNotification notification, params string[] paths) + protected async Task Import(ProgressNotification notification, params string[] paths) { notification.Progress = 0; notification.Text = "Import is initialising..."; var term = $"{typeof(TModel).Name.Replace("Info", "").ToLower()}"; - List imported = new List(); + var tasks = new List(); int current = 0; foreach (string path in paths) { - if (notification.State == ProgressNotificationState.Cancelled) - // user requested abort - return; - - try + tasks.Add(Import(path, notification.CancellationToken).ContinueWith(t => { - var text = "Importing "; - - if (path.Length > 1) - text += $"{++current} of {paths.Length} {term}s.."; - else - text += $"{term}.."; - - // only show the filename if it isn't a temporary one (as those look ugly). - if (!path.Contains(Path.GetTempPath())) - text += $"\n{Path.GetFileName(path)}"; - - notification.Text = text; - - imported.Add(Import(path)); - - notification.Progress = (float)current / paths.Length; - } - catch (Exception e) - { - e = e.InnerException ?? e; - Logger.Error(e, $@"Could not import ({Path.GetFileName(path)})"); - } - } - - if (imported.Count == 0) - { - notification.Text = "Import failed!"; - notification.State = ProgressNotificationState.Cancelled; - } - else - { - notification.CompletionText = imported.Count == 1 - ? $"Imported {imported.First()}!" - : $"Imported {current} {term}s!"; - - if (imported.Count > 0 && PresentImport != null) - { - notification.CompletionText += " Click to view."; - notification.CompletionClickAction = () => + lock (notification) { - PresentImport?.Invoke(imported); - return true; - }; - } + current++; - notification.State = ProgressNotificationState.Completed; + notification.Text = $"Imported {current} of {paths.Length} {term}s"; + notification.Progress = (float)current / paths.Length; + } + + if (t.Exception != null) + { + var e = t.Exception.InnerException ?? t.Exception; + Logger.Error(e, $@"Could not import ({Path.GetFileName(path)})"); + } + })); } + + await Task.WhenAll(tasks); + + // if (imported.Count == 0) + // { + // notification.Text = "Import failed!"; + // notification.State = ProgressNotificationState.Cancelled; + // } + // else + // { + // notification.CompletionText = imported.Count == 1 + // ? $"Imported {imported.First()}!" + // : $"Imported {current} {term}s!"; + // + // if (imported.Count > 0 && PresentImport != null) + // { + // notification.CompletionText += " Click to view."; + // notification.CompletionClickAction = () => + // { + // PresentImport?.Invoke(imported); + // return true; + // }; + // } + // + // notification.State = ProgressNotificationState.Completed; + // } } /// /// Import one from the filesystem and delete the file on success. /// /// The archive location on disk. + /// An optional cancellation token. /// The imported model, if successful. - public TModel Import(string path) + public async Task Import(string path, CancellationToken cancellationToken = default) { TModel import; using (ArchiveReader reader = getReaderFrom(path)) - import = Import(reader); + import = await Import(reader, cancellationToken); // We may or may not want to delete the file depending on where it is stored. // e.g. reconstructing/repairing database with items from default storage. @@ -243,7 +241,8 @@ namespace osu.Game.Database /// Import an item from an . /// /// The archive to be imported. - public TModel Import(ArchiveReader archive) + /// An optional cancellation token. + public async Task Import(ArchiveReader archive, CancellationToken cancellationToken = default) { try { @@ -253,7 +252,7 @@ namespace osu.Game.Database model.Hash = computeHash(archive); - return Import(model, archive); + return await Import(model, archive, cancellationToken); } catch (Exception e) { @@ -288,7 +287,8 @@ namespace osu.Game.Database /// /// The model to be imported. /// An optional archive to use for model population. - public TModel Import(TModel item, ArchiveReader archive = null) + /// An optional cancellation token. + public async Task Import(TModel item, ArchiveReader archive = null, CancellationToken cancellationToken = default) => await Task.Factory.StartNew(async () => { delayEvents(); @@ -296,17 +296,31 @@ namespace osu.Game.Database { Logger.Log($"Importing {item}...", LoggingTarget.Database); + if (archive != null) + item.Files = createFileInfos(archive, Files); + + var localItem = item; + + try + { + await Populate(item, archive, cancellationToken); + } + catch (TaskCanceledException) + { + return item = null; + } + finally + { + if (!Delete(localItem)) + Files.Dereference(localItem.Files.Select(f => f.FileInfo).ToArray()); + } + using (var write = ContextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes. { try { if (!write.IsTransactionLeader) throw new InvalidOperationException($"Ensure there is no parent transaction so errors can correctly be handled by {this}"); - if (archive != null) - item.Files = createFileInfos(archive, Files); - - Populate(item, archive); - var existing = CheckForExisting(item); if (existing != null) @@ -332,6 +346,9 @@ namespace osu.Game.Database } catch (Exception e) { + if (!Delete(item)) + Files.Dereference(item.Files.Select(f => f.FileInfo).ToArray()); + write.Errors.Add(e); throw; } @@ -351,7 +368,7 @@ namespace osu.Game.Database } return item; - } + }, CancellationToken.None, TaskCreationOptions.None, importScheduler).Unwrap(); /// /// Perform an update of the specified item. @@ -516,24 +533,24 @@ namespace osu.Game.Database /// /// This is a temporary method and will likely be replaced by a full-fledged (and more correctly placed) migration process in the future. /// - public Task ImportFromStableAsync() + public async Task ImportFromStableAsync() { var stable = GetStableStorage?.Invoke(); if (stable == null) { Logger.Log("No osu!stable installation available!", LoggingTarget.Information, LogLevel.Error); - return Task.CompletedTask; + return; } if (!stable.ExistsDirectory(ImportFromStablePath)) { // This handles situations like when the user does not have a Skins folder Logger.Log($"No {ImportFromStablePath} folder available in osu!stable installation", LoggingTarget.Information, LogLevel.Error); - return Task.CompletedTask; + return; } - return Task.Factory.StartNew(() => Import(stable.GetDirectories(ImportFromStablePath).Select(f => stable.GetFullPath(f)).ToArray()), TaskCreationOptions.LongRunning); + await Task.Run(async () => await Import(stable.GetDirectories(ImportFromStablePath).Select(f => stable.GetFullPath(f)).ToArray())); } #endregion @@ -552,9 +569,8 @@ namespace osu.Game.Database /// /// The model to populate. /// The archive to use as a reference for population. May be null. - protected virtual void Populate(TModel model, [CanBeNull] ArchiveReader archive) - { - } + /// An optional cancellation token. + protected virtual async Task Populate(TModel model, [CanBeNull] ArchiveReader archive, CancellationToken cancellationToken = default) => await Task.CompletedTask; /// /// Perform any final actions before the import to database executes. diff --git a/osu.Game/Database/ICanAcceptFiles.cs b/osu.Game/Database/ICanAcceptFiles.cs index f55d0c389e..b9f882468d 100644 --- a/osu.Game/Database/ICanAcceptFiles.cs +++ b/osu.Game/Database/ICanAcceptFiles.cs @@ -1,6 +1,8 @@ // Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. // See the LICENCE file in the repository root for full licence text. +using System.Threading.Tasks; + namespace osu.Game.Database { /// @@ -12,7 +14,7 @@ namespace osu.Game.Database /// Import the specified paths. /// /// The files which should be imported. - void Import(params string[] paths); + Task Import(params string[] paths); /// /// An array of accepted file extensions (in the standard format of ".abc"). diff --git a/osu.Game/IPC/ArchiveImportIPCChannel.cs b/osu.Game/IPC/ArchiveImportIPCChannel.cs index fc747cd446..484db932f8 100644 --- a/osu.Game/IPC/ArchiveImportIPCChannel.cs +++ b/osu.Game/IPC/ArchiveImportIPCChannel.cs @@ -38,7 +38,7 @@ namespace osu.Game.IPC } if (importer.HandledExtensions.Contains(Path.GetExtension(path)?.ToLowerInvariant())) - importer.Import(path); + await importer.Import(path); } } diff --git a/osu.Game/OsuGameBase.cs b/osu.Game/OsuGameBase.cs index f9128687d6..d6b8caaf5b 100644 --- a/osu.Game/OsuGameBase.cs +++ b/osu.Game/OsuGameBase.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Reflection; +using System.Threading.Tasks; using osu.Framework.Allocation; using osu.Framework.Audio; using osu.Framework.Bindables; @@ -268,13 +269,13 @@ namespace osu.Game private readonly List fileImporters = new List(); - public void Import(params string[] paths) + public async Task Import(params string[] paths) { var extension = Path.GetExtension(paths.First())?.ToLowerInvariant(); foreach (var importer in fileImporters) if (importer.HandledExtensions.Contains(extension)) - importer.Import(paths); + await importer.Import(paths); } public string[] HandledExtensions => fileImporters.SelectMany(i => i.HandledExtensions).ToArray(); diff --git a/osu.Game/Overlays/Notifications/ProgressNotification.cs b/osu.Game/Overlays/Notifications/ProgressNotification.cs index 857a0bda9e..c8e081d29f 100644 --- a/osu.Game/Overlays/Notifications/ProgressNotification.cs +++ b/osu.Game/Overlays/Notifications/ProgressNotification.cs @@ -2,6 +2,7 @@ // See the LICENCE file in the repository root for full licence text. using System; +using System.Threading; using osu.Framework.Allocation; using osu.Framework.Graphics; using osu.Framework.Graphics.Containers; @@ -36,6 +37,10 @@ namespace osu.Game.Overlays.Notifications State = state; } + private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + + public CancellationToken CancellationToken => cancellationTokenSource.Token; + public virtual ProgressNotificationState State { get => state; @@ -62,6 +67,8 @@ namespace osu.Game.Overlays.Notifications break; case ProgressNotificationState.Cancelled: + cancellationTokenSource.Cancel(); + Light.Colour = colourCancelled; Light.Pulsate = false; progressBar.Active = false; diff --git a/osu.Game/Screens/Menu/Intro.cs b/osu.Game/Screens/Menu/Intro.cs index 98a2fe8f13..cf5d247482 100644 --- a/osu.Game/Screens/Menu/Intro.cs +++ b/osu.Game/Screens/Menu/Intro.cs @@ -66,7 +66,7 @@ namespace osu.Game.Screens.Menu if (setInfo == null) { // we need to import the default menu background beatmap - setInfo = beatmaps.Import(new ZipArchiveReader(game.Resources.GetStream(@"Tracks/circles.osz"), "circles.osz")); + setInfo = beatmaps.Import(new ZipArchiveReader(game.Resources.GetStream(@"Tracks/circles.osz"), "circles.osz")).Result; setInfo.Protected = true; beatmaps.Update(setInfo); diff --git a/osu.Game/Screens/Play/Player.cs b/osu.Game/Screens/Play/Player.cs index d8389fa6d9..949f9e5ff2 100644 --- a/osu.Game/Screens/Play/Player.cs +++ b/osu.Game/Screens/Play/Player.cs @@ -279,7 +279,7 @@ namespace osu.Game.Screens.Play var score = CreateScore(); if (DrawableRuleset.ReplayScore == null) - scoreManager.Import(score); + scoreManager.Import(score).Wait(); this.Push(CreateResults(score)); diff --git a/osu.Game/Skinning/SkinManager.cs b/osu.Game/Skinning/SkinManager.cs index 3a4d44f608..73cc47ea47 100644 --- a/osu.Game/Skinning/SkinManager.cs +++ b/osu.Game/Skinning/SkinManager.cs @@ -5,6 +5,8 @@ using System; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; using osu.Framework.Audio; using osu.Framework.Audio.Sample; @@ -71,9 +73,9 @@ namespace osu.Game.Skinning protected override SkinInfo CreateModel(ArchiveReader archive) => new SkinInfo { Name = archive.Name }; - protected override void Populate(SkinInfo model, ArchiveReader archive) + protected override async Task Populate(SkinInfo model, ArchiveReader archive, CancellationToken cancellationToken = default) { - base.Populate(model, archive); + await base.Populate(model, archive, cancellationToken); Skin reference = getSkin(model);