diff --git a/src/main/scala/monads/MonadFunctor.scala b/src/main/scala/monads/MonadFunctor.scala
index 789d5737e3613446f423af4e5e81ce33edcd901d..f85bbedcf14a515303f5fb954941421a37b626e6 100644
--- a/src/main/scala/monads/MonadFunctor.scala
+++ b/src/main/scala/monads/MonadFunctor.scala
@@ -2,18 +2,22 @@ package monads
 
 trait Monad[F[_]]:
   def pure[A](a: A): F[A]
-  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
 
-  def map2[A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C): F[C] = ???
+  extension [A](fa: F[A])
+    def flatMap[B](f: A => F[B]): F[B]
 
-  def sequence[A](fas: List[F[A]]): F[List[A]] = ???
+    def map2[B, C](fb: F[B])(f: (A, B) => C): F[C] = ???
+
+  extension [A](fas: List[F[A]])
+    def sequence: F[List[A]] = ???
 
   def compose[A, B, C](f: A => F[B])(g: B => F[C]): A => F[C] = ???
 
 trait Functor[F[_]]:
-  def map[A, B](a: F[A])(f: A => B): F[B]
+  extension [A](fa: F[A])
+    def map[B](f: A => B): F[B]
 
 object Functor:
-  def functorFromMonad[F[_]](M: Monad[F]): Functor[F] = new Functor[F] {
-    def map[A, B](a: F[A])(f: A => B): F[B] = ???
-  }
+  def functorFromMonad[F[_]](using m: Monad[F]): Functor[F] = new Functor[F]:
+    extension [A](fa: F[A])
+      def map[B](f: A => B): F[B] = ???
diff --git a/src/main/scala/monads/MonadId.scala b/src/main/scala/monads/MonadId.scala
index 2c87fba791a938da05179b64b29c23247a824773..2d26720334174cc142d32d5d3392bc0a407da1ef 100644
--- a/src/main/scala/monads/MonadId.scala
+++ b/src/main/scala/monads/MonadId.scala
@@ -2,7 +2,8 @@ package monads
 final case class Id[A](value: A)
 
 object Id:
-    // No tests. If it compiles, it's correct.
-    given Monad[Id] with
-        def pure[A](a: A): Id[A] = ???
-        def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = ???
+// No tests. If it compiles, it's correct.
+  given Monad[Id] with
+    def pure[A](a: A): Id[A] = ???
+    extension [A](fa: Id[A])
+      def flatMap[B](f: A => Id[B]): Id[B] = ???
diff --git a/src/test/scala/monads/MonadFunctorSpec.scala b/src/test/scala/monads/MonadFunctorSpec.scala
index 03d05a3f83f53df80723dc1c2fe77ff0e56a67dd..0d50e7af67a4d6b8974548ce5e69bca1ccd7ed1a 100644
--- a/src/test/scala/monads/MonadFunctorSpec.scala
+++ b/src/test/scala/monads/MonadFunctorSpec.scala
@@ -5,11 +5,12 @@ import testutil.PendingIfUnimplemented
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
 
+given Monad[Option] with
+    def pure[A](a: A): Option[A] = Some(a)
+    extension [A](fa: Option[A]) def flatMap[B](f: A => Option[B]): Option[B] = fa.flatMap(f)
+
+
 class MonadFunctorSpec extends AnyFlatSpec with Matchers with AppendedClues with PendingIfUnimplemented:
   "functorFromMonad" should "return a working functor" in {
-    Functor.functorFromMonad[Option](new Monad[Option] {
-      def pure[A](a: A): Option[A] = Some(a)
-      def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] =
-        fa.flatMap(f)
-    }).map[Int, String](Some(3))(_.toString) shouldBe Some("3")
+    Functor.functorFromMonad[Option].map(Some(3))(_.toString) shouldBe Some("3")
   }
diff --git a/src/test/scala/monads/MonadSpec.scala b/src/test/scala/monads/MonadSpec.scala
index e806cba79290494db45042e2787c153f3b81f515..50343a966feb19bb860115d7b5a373c074e6b20e 100644
--- a/src/test/scala/monads/MonadSpec.scala
+++ b/src/test/scala/monads/MonadSpec.scala
@@ -7,26 +7,25 @@ import org.scalatest.matchers.should.Matchers
 
 class MonadSpec extends AnyFlatSpec with Matchers with AppendedClues with PendingIfUnimplemented:
   "map2" should "combine two monadic values" in {
-    tupleMonad.map2(Tuple1(1), Tuple1(2))((a, b) => a + b) shouldBe Tuple1(3)
+    Tuple1(1).map2(Tuple1(2))((a, b) => a + b) shouldBe Tuple1(3)
   }
 
   "sequence" should "correctly sequence all values" in {
-    tupleMonad.sequence(List(
+    List(
       Tuple1(1),
       Tuple1(2),
       Tuple1(3)
-    )) shouldBe Tuple1(List(1,2,3))
+    ).sequence shouldBe Tuple1(List(1,2,3))
   }
 
   "compose" should "correctly compose monad functions" in {
     val f: Int => Tuple1[Double] = a => Tuple1(a/2.0)
     val g: Double => Tuple1[String] = a => Tuple1(a.toString)
 
-    tupleMonad.compose(f)(g)(3) shouldBe Tuple1("1.5")
+    summon[Monad[Tuple1]].compose(f)(g)(3) shouldBe Tuple1("1.5")
   }
 
-  val tupleMonad = new Monad[Tuple1] {
+  given Monad[Tuple1] with
     def pure[A](a: A): Tuple1[A] = Tuple1(a)
-    def flatMap[A, B](fa: Tuple1[A])(f: A => Tuple1[B]): Tuple1[B] =
-      f(fa._1)
-  }
+    extension [A](fa: Tuple1[A])
+      def flatMap[B](f: A => Tuple1[B]): Tuple1[B] = f(fa._1)