@@ -356,12 +356,26 @@ struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
umem->owning_mm = root->umem.owning_mm;
odp_data->page_shift = PAGE_SHIFT;
- ret = ib_init_umem_odp(odp_data);
- if (ret) {
- kfree(odp_data);
- return ERR_PTR(ret);
+ /*
+ * A mmget must be held when registering a notifier, the owming_mm only
+ * has a mm_grab at this point.
+ */
+ if (!mmget_not_zero(umem->owning_mm)) {
+ ret = -EFAULT;
+ goto out_free;
}
+
+ ret = ib_init_umem_odp(odp_data);
+ if (ret)
+ goto out_mmput;
+ mmput(umem->owning_mm);
return odp_data;
+
+out_mmput:
+ mmput(umem->owning_mm);
+out_free:
+ kfree(odp_data);
+ return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_umem_odp_alloc_child);