@@ -200,7 +200,7 @@ static int __iommu_probe_device(struct device *dev, struct list_head *group_list
dev->iommu->iommu_dev = iommu_dev;
group = iommu_group_get_for_dev(dev);
- if (!IS_ERR(group)) {
+ if (IS_ERR(group)) {
ret = PTR_ERR(group);
goto out_release;
}
@@ -1600,6 +1600,37 @@ static int add_iommu_group(struct device *dev, void *data)
return ret;
}
+static int probe_iommu_group(struct device *dev, void *data)
+{
+ const struct iommu_ops *ops = dev->bus->iommu_ops;
+ struct list_head *group_list = data;
+ int ret;
+
+ if (!dev_iommu_get(dev))
+ return -ENOMEM;
+
+ if (!try_module_get(ops->owner)) {
+ ret = -EINVAL;
+ goto err_free_dev_iommu;
+ }
+
+ ret = __iommu_probe_device(dev, group_list);
+ if (ret)
+ goto err_module_put;
+
+ return 0;
+
+err_module_put:
+ module_put(ops->owner);
+err_free_dev_iommu:
+ dev_iommu_free(dev);
+
+ if (ret == -ENODEV)
+ ret = 0;
+
+ return ret;
+}
+
static int remove_iommu_group(struct device *dev, void *data)
{
iommu_release_device(dev);
@@ -1659,10 +1690,127 @@ static int iommu_bus_notifier(struct notifier_block *nb,
return 0;
}
+struct __group_domain_type {
+ struct device *dev;
+ unsigned int type;
+};
+
+static int probe_get_default_domain_type(struct device *dev, void *data)
+{
+ const struct iommu_ops *ops = dev->bus->iommu_ops;
+ struct __group_domain_type *gtype = data;
+ unsigned int type = 0;
+
+ if (ops->def_domain_type)
+ type = ops->def_domain_type(dev);
+
+ if (type) {
+ if (gtype->type && gtype->type != type) {
+ dev_warn(dev, "Device needs domain type %s, but device %s in the same iommu group requires type %s - using default\n",
+ iommu_domain_type_str(type),
+ dev_name(gtype->dev),
+ iommu_domain_type_str(gtype->type));
+ gtype->type = 0;
+ }
+
+ if (!gtype->dev) {
+ gtype->dev = dev;
+ gtype->type = type;
+ }
+ }
+
+ return 0;
+}
+
+static void probe_alloc_default_domain(struct bus_type *bus,
+ struct iommu_group *group)
+{
+ struct __group_domain_type gtype;
+
+ memset(>ype, 0, sizeof(gtype));
+
+ /* Ask for default domain requirements of all devices in the group */
+ __iommu_group_for_each_dev(group, >ype,
+ probe_get_default_domain_type);
+
+ if (!gtype.type)
+ gtype.type = iommu_def_domain_type;
+
+ iommu_group_alloc_default_domain(bus, group, gtype.type);
+}
+
+static int iommu_group_do_dma_attach(struct device *dev, void *data)
+{
+ struct iommu_domain *domain = data;
+ const struct iommu_ops *ops;
+ int ret;
+
+ ret = __iommu_attach_device(domain, dev);
+
+ ops = domain->ops;
+
+ if (ret == 0 && ops->probe_finalize)
+ ops->probe_finalize(dev);
+
+ return ret;
+}
+
+static int __iommu_group_dma_attach(struct iommu_group *group)
+{
+ return __iommu_group_for_each_dev(group, group->default_domain,
+ iommu_group_do_dma_attach);
+}
+
+static int bus_iommu_probe(struct bus_type *bus)
+{
+ const struct iommu_ops *ops = bus->iommu_ops;
+ int ret;
+
+ if (ops->probe_device) {
+ struct iommu_group *group, *next;
+ LIST_HEAD(group_list);
+
+ /*
+ * This code-path does not allocate the default domain when
+ * creating the iommu group, so do it after the groups are
+ * created.
+ */
+ ret = bus_for_each_dev(bus, NULL, &group_list, probe_iommu_group);
+ if (ret)
+ return ret;
+
+ list_for_each_entry_safe(group, next, &group_list, entry) {
+ /* Remove item from the list */
+ list_del_init(&group->entry);
+
+ mutex_lock(&group->mutex);
+
+ /* Try to allocate default domain */
+ probe_alloc_default_domain(bus, group);
+
+ if (!group->default_domain) {
+ mutex_unlock(&group->mutex);
+ continue;
+ }
+
+ ret = __iommu_group_dma_attach(group);
+
+ mutex_unlock(&group->mutex);
+
+ if (ret)
+ break;
+ }
+ } else {
+ ret = bus_for_each_dev(bus, NULL, NULL, add_iommu_group);
+ }
+
+ return ret;
+}
+
static int iommu_bus_init(struct bus_type *bus, const struct iommu_ops *ops)
{
- int err;
struct notifier_block *nb;
+ int err;
nb = kzalloc(sizeof(struct notifier_block), GFP_KERNEL);
if (!nb)
@@ -1674,7 +1822,7 @@ static int iommu_bus_init(struct bus_type *bus, const struct iommu_ops *ops)
if (err)
goto out_free;
- err = bus_for_each_dev(bus, NULL, NULL, add_iommu_group);
+ err = bus_iommu_probe(bus);
if (err)
goto out_err;