diff --git a/osu.Game.Tests/Mods/ModUtilsTest.cs b/osu.Game.Tests/Mods/ModUtilsTest.cs index 7384471c41..9f27289d7e 100644 --- a/osu.Game.Tests/Mods/ModUtilsTest.cs +++ b/osu.Game.Tests/Mods/ModUtilsTest.cs @@ -21,6 +21,14 @@ namespace osu.Game.Tests.Mods Assert.That(ModUtils.CheckCompatibleSet(new[] { mod.Object })); } + [Test] + public void TestModIsCompatibleByItselfWithIncompatibleInterface() + { + var mod = new Mock(); + mod.Setup(m => m.IncompatibleMods).Returns(new[] { typeof(IModCompatibilitySpecification) }); + Assert.That(ModUtils.CheckCompatibleSet(new[] { mod.Object })); + } + [Test] public void TestIncompatibleThroughTopLevel() { @@ -34,6 +42,20 @@ namespace osu.Game.Tests.Mods Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, mod1.Object }), Is.False); } + [Test] + public void TestIncompatibleThroughInterface() + { + var mod1 = new Mock(); + var mod2 = new Mock(); + + mod1.Setup(m => m.IncompatibleMods).Returns(new[] { typeof(IModCompatibilitySpecification) }); + mod2.Setup(m => m.IncompatibleMods).Returns(new[] { typeof(IModCompatibilitySpecification) }); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, mod2.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, mod1.Object }), Is.False); + } + [Test] public void TestMultiModIncompatibleWithTopLevel() { @@ -149,11 +171,15 @@ namespace osu.Game.Tests.Mods Assert.That(invalid.Select(t => t.GetType()), Is.EquivalentTo(expectedInvalid)); } - public abstract class CustomMod1 : Mod + public abstract class CustomMod1 : Mod, IModCompatibilitySpecification { } - public abstract class CustomMod2 : Mod + public abstract class CustomMod2 : Mod, IModCompatibilitySpecification + { + } + + public interface IModCompatibilitySpecification { } } diff --git a/osu.Game/Utils/ModUtils.cs b/osu.Game/Utils/ModUtils.cs index 1c3558fc90..98766cb844 100644 --- a/osu.Game/Utils/ModUtils.cs +++ b/osu.Game/Utils/ModUtils.cs @@ -60,6 +60,9 @@ namespace osu.Game.Utils { foreach (var invalid in combination.Where(m => type.IsInstanceOfType(m))) { + if (invalid == mod) + continue; + invalidMods ??= new List(); invalidMods.Add(invalid); }