From 3ec7dc3bb94972fe16a7359db595faf8136bf5fa Mon Sep 17 00:00:00 2001 From: Dean Herbert Date: Sun, 4 Jul 2021 17:59:39 +0900 Subject: [PATCH] Update tests in line with thread safety check --- .../Database/TestRealmKeyBindingStore.cs | 54 +++++++++++-------- osu.Game/Database/RealmContextFactory.cs | 21 +++++--- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/osu.Game.Tests/Database/TestRealmKeyBindingStore.cs b/osu.Game.Tests/Database/TestRealmKeyBindingStore.cs index cac331451b..642ecf00b8 100644 --- a/osu.Game.Tests/Database/TestRealmKeyBindingStore.cs +++ b/osu.Game.Tests/Database/TestRealmKeyBindingStore.cs @@ -38,19 +38,28 @@ namespace osu.Game.Tests.Database [Test] public void TestDefaultsPopulationAndQuery() { - Assert.That(query().Count, Is.EqualTo(0)); + Assert.That(queryCount(), Is.EqualTo(0)); KeyBindingContainer testContainer = new TestKeyBindingContainer(); keyBindingStore.Register(testContainer); - Assert.That(query().Count, Is.EqualTo(3)); + Assert.That(queryCount(), Is.EqualTo(3)); - Assert.That(query().Where(k => k.ActionInt == (int)GlobalAction.Back).Count, Is.EqualTo(1)); - Assert.That(query().Where(k => k.ActionInt == (int)GlobalAction.Select).Count, Is.EqualTo(2)); + Assert.That(queryCount(GlobalAction.Back), Is.EqualTo(1)); + Assert.That(queryCount(GlobalAction.Select), Is.EqualTo(2)); } - private IQueryable query() => realmContextFactory.Context.All(); + private int queryCount(GlobalAction? match = null) + { + using (var usage = realmContextFactory.GetForRead()) + { + var results = usage.Realm.All(); + if (match.HasValue) + results = results.Where(k => k.ActionInt == (int)match.Value); + return results.Count(); + } + } [Test] public void TestUpdateViaQueriedReference() @@ -59,25 +68,28 @@ namespace osu.Game.Tests.Database keyBindingStore.Register(testContainer); - var backBinding = query().Single(k => k.ActionInt == (int)GlobalAction.Back); - - Assert.That(backBinding.KeyCombination.Keys, Is.EquivalentTo(new[] { InputKey.Escape })); - - var tsr = ThreadSafeReference.Create(backBinding); - - using (var usage = realmContextFactory.GetForWrite()) + using (var primaryUsage = realmContextFactory.GetForRead()) { - var binding = usage.Realm.ResolveReference(tsr); - binding.KeyCombination = new KeyCombination(InputKey.BackSpace); + var backBinding = primaryUsage.Realm.All().Single(k => k.ActionInt == (int)GlobalAction.Back); - usage.Commit(); + Assert.That(backBinding.KeyCombination.Keys, Is.EquivalentTo(new[] { InputKey.Escape })); + + var tsr = ThreadSafeReference.Create(backBinding); + + using (var usage = realmContextFactory.GetForWrite()) + { + var binding = usage.Realm.ResolveReference(tsr); + binding.KeyCombination = new KeyCombination(InputKey.BackSpace); + + usage.Commit(); + } + + Assert.That(backBinding.KeyCombination.Keys, Is.EquivalentTo(new[] { InputKey.BackSpace })); + + // check still correct after re-query. + backBinding = primaryUsage.Realm.All().Single(k => k.ActionInt == (int)GlobalAction.Back); + Assert.That(backBinding.KeyCombination.Keys, Is.EquivalentTo(new[] { InputKey.BackSpace })); } - - Assert.That(backBinding.KeyCombination.Keys, Is.EquivalentTo(new[] { InputKey.BackSpace })); - - // check still correct after re-query. - backBinding = query().Single(k => k.ActionInt == (int)GlobalAction.Back); - Assert.That(backBinding.KeyCombination.Keys, Is.EquivalentTo(new[] { InputKey.BackSpace })); } [TearDown] diff --git a/osu.Game/Database/RealmContextFactory.cs b/osu.Game/Database/RealmContextFactory.cs index 3354b97849..f706c37419 100644 --- a/osu.Game/Database/RealmContextFactory.cs +++ b/osu.Game/Database/RealmContextFactory.cs @@ -4,6 +4,7 @@ using System; using System.Threading; using osu.Framework.Allocation; +using osu.Framework.Development; using osu.Framework.Graphics; using osu.Framework.Logging; using osu.Framework.Platform; @@ -46,15 +47,21 @@ namespace osu.Game.Database { get { - if (context == null) + if (!ThreadSafety.IsUpdateThread) + throw new InvalidOperationException($"Use {nameof(GetForRead)} when performing realm operations from a non-update thread"); + + lock (updateContextLock) { - context = createContext(); - Logger.Log($"Opened realm \"{context.Config.DatabasePath}\" at version {context.Config.SchemaVersion}"); + if (context == null) + { + context = createContext(); + Logger.Log($"Opened realm \"{context.Config.DatabasePath}\" at version {context.Config.SchemaVersion}"); + } + + // creating a context will ensure our schema is up-to-date and migrated. + + return context; } - - // creating a context will ensure our schema is up-to-date and migrated. - - return context; } }