diff --git a/osu.Game.Tests.Android/osu.Game.Tests.Android.csproj b/osu.Game.Tests.Android/osu.Game.Tests.Android.csproj index c44ed69c4d..19e36a63f1 100644 --- a/osu.Game.Tests.Android/osu.Game.Tests.Android.csproj +++ b/osu.Game.Tests.Android/osu.Game.Tests.Android.csproj @@ -69,5 +69,9 @@ osu.Game + + + + \ No newline at end of file diff --git a/osu.Game.Tests.iOS/osu.Game.Tests.iOS.csproj b/osu.Game.Tests.iOS/osu.Game.Tests.iOS.csproj index ca68369ebb..67b2298f4c 100644 --- a/osu.Game.Tests.iOS/osu.Game.Tests.iOS.csproj +++ b/osu.Game.Tests.iOS/osu.Game.Tests.iOS.csproj @@ -45,6 +45,7 @@ + \ No newline at end of file diff --git a/osu.Game.Tests/Mods/ModUtilsTest.cs b/osu.Game.Tests/Mods/ModUtilsTest.cs new file mode 100644 index 0000000000..e4ded602aa --- /dev/null +++ b/osu.Game.Tests/Mods/ModUtilsTest.cs @@ -0,0 +1,110 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using Moq; +using NUnit.Framework; +using osu.Game.Rulesets.Mods; +using osu.Game.Utils; + +namespace osu.Game.Tests.Mods +{ + [TestFixture] + public class ModUtilsTest + { + [Test] + public void TestModIsCompatibleByItself() + { + var mod = new Mock(); + Assert.That(ModUtils.CheckCompatibleSet(new[] { mod.Object })); + } + + [Test] + public void TestIncompatibleThroughTopLevel() + { + var mod1 = new Mock(); + var mod2 = new Mock(); + + mod1.Setup(m => m.IncompatibleMods).Returns(new[] { mod2.Object.GetType() }); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, mod2.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, mod1.Object }), Is.False); + } + + [Test] + public void TestMultiModIncompatibleWithTopLevel() + { + var mod1 = new Mock(); + + // The nested mod. + var mod2 = new Mock(); + mod2.Setup(m => m.IncompatibleMods).Returns(new[] { mod1.Object.GetType() }); + + var multiMod = new MultiMod(new MultiMod(mod2.Object)); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { multiMod, mod1.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, multiMod }), Is.False); + } + + [Test] + public void TestTopLevelIncompatibleWithMultiMod() + { + // The nested mod. + var mod1 = new Mock(); + var multiMod = new MultiMod(new MultiMod(mod1.Object)); + + var mod2 = new Mock(); + mod2.Setup(m => m.IncompatibleMods).Returns(new[] { typeof(CustomMod1) }); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { multiMod, mod2.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, multiMod }), Is.False); + } + + [Test] + public void TestCompatibleMods() + { + var mod1 = new Mock(); + var mod2 = new Mock(); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, mod2.Object }), Is.True); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, mod1.Object }), Is.True); + } + + [Test] + public void TestIncompatibleThroughBaseType() + { + var mod1 = new Mock(); + var mod2 = new Mock(); + mod2.Setup(m => m.IncompatibleMods).Returns(new[] { typeof(Mod) }); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, mod2.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, mod1.Object }), Is.False); + } + + [Test] + public void TestAllowedThroughMostDerivedType() + { + var mod = new Mock(); + Assert.That(ModUtils.CheckAllowed(new[] { mod.Object }, new[] { mod.Object.GetType() })); + } + + [Test] + public void TestNotAllowedThroughBaseType() + { + var mod = new Mock(); + Assert.That(ModUtils.CheckAllowed(new[] { mod.Object }, new[] { typeof(Mod) }), Is.False); + } + + public abstract class CustomMod1 : Mod + { + } + + public abstract class CustomMod2 : Mod + { + } + } +} diff --git a/osu.Game/Utils/ModUtils.cs b/osu.Game/Utils/ModUtils.cs index 808dba2900..9336add465 100644 --- a/osu.Game/Utils/ModUtils.cs +++ b/osu.Game/Utils/ModUtils.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Linq; -using osu.Framework.Extensions.TypeExtensions; using osu.Game.Rulesets.Mods; #nullable enable @@ -29,7 +28,7 @@ namespace osu.Game.Utils { // Prevent multiple-enumeration. var combinationList = combination as ICollection ?? combination.ToArray(); - return CheckCompatibleSet(combinationList) && CheckAllowed(combinationList, allowedTypes); + return CheckCompatibleSet(combinationList, out _) && CheckAllowed(combinationList, allowedTypes); } /// @@ -38,32 +37,32 @@ namespace osu.Game.Utils /// The combination to check. /// Whether all s in the combination are compatible with each-other. public static bool CheckCompatibleSet(IEnumerable combination) + => CheckCompatibleSet(combination, out _); + + /// + /// Checks that all s in a combination are compatible with each-other. + /// + /// The combination to check. + /// Any invalid mods in the set. + /// Whether all s in the combination are compatible with each-other. + public static bool CheckCompatibleSet(IEnumerable combination, out List? invalidMods) { - var incompatibleTypes = new HashSet(); - var incomingTypes = new HashSet(); + combination = FlattenMods(combination).ToArray(); + invalidMods = null; - foreach (var mod in combination.SelectMany(FlattenMod)) + foreach (var mod in combination) { - // Add the new mod incompatibilities, checking whether any match the existing mod types. - foreach (var t in mod.IncompatibleMods) + foreach (var type in mod.IncompatibleMods) { - if (incomingTypes.Contains(t)) - return false; - - incompatibleTypes.Add(t); - } - - // Add the new mod types, checking whether any match the incompatible types. - foreach (var t in mod.GetType().EnumerateBaseTypes()) - { - if (incomingTypes.Contains(t)) - return false; - - incomingTypes.Add(t); + foreach (var invalid in combination.Where(m => type.IsInstanceOfType(m))) + { + invalidMods ??= new List(); + invalidMods.Add(invalid); + } } } - return true; + return invalidMods == null; } ///