当前位置: 首页 > 工具软件 > dict_build > 使用案例 >

mmlab的build 和 iter_based_runner:

计和顺
2023-12-01
def build(cfg, registry, default_args=None):
    """Build module function.

    Args:
        cfg (dict): Configuration for building modules.
        registry (obj): ``registry`` object.
        default_args (dict, optional): Default arguments. Defaults to None.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)

    return build_from_cfg(cfg, registry, default_args)


def build_backbone(cfg):
    """Build backbone.

    Args:
        cfg (dict): Configuration for building backbone.
    """
    return build(cfg, BACKBONES)


def build_component(cfg):
    """Build component.

    Args:
        cfg (dict): Configuration for building component.
    """
    return build(cfg, COMPONENTS)


def build_loss(cfg):
    """Build loss.

    Args:
        cfg (dict): Configuration for building loss.
    """
    return build(cfg, LOSSES)


def build_model(cfg, train_cfg=None, test_cfg=None):
    """Build model.

    Args:
        cfg (dict): Configuration for building model.
        train_cfg (dict): Training configuration. Default: None.
        test_cfg (dict): Testing configuration. Default: None.
    """
    return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

如上,有这么多个bulid 分别是build model, build loss , build componet, build backbone 

这个mmeding中,model和这个backbone的区别:model就是相当于啥都搞好了,它的代码段经常包含一个train_step。而这个backbone就纯粹的网络模型设计。

并且,每次build都要经过build再到build_from_cfg。然后,通过register这个挂钩来关联的。

今天又一个路径问题:我喜欢先用终端指令去run下,然后再debug。但是这时候又不把;原来的config文件改成绝对路径的还是用那个相对路径。。。

 

关于mmediting,它的bulid register。虽然 它的解耦性很好,但是对我这种代码能力差的就改起来不友好。大概也知道 这种东西是用来干嘛的,可是你要加一个新网络进去,就有很多地方配置它的Model.register之类的东西,而且 好像现在还报错。。。。

它这个调用训练的过程,就是mmedit里面的restore 的train_step调用 然后 再调用对应网络去forward;;;

关于 每次把装载的数据dataloader 怎么一个一个迭代喂到网络中:这个代码都在这个apis/train.py里面的这个builder runner这步骤里。

然后,IterBasedRunner这个类里面的方法有一个run方法就是真正的进行train的过程:

self.call_hook(before_run) 这个before_run 调这个的钩子会进入的网络net里面

iter_loaders = [Iterloader(x) for x in data_loaders]  这个 我感觉是把在dataloaders的数据给到这个每次迭代的iter_loaders里面。

然后

在epoch的过程中,有一行代码:

iter_runner = getattr(self, mode) 这个代码就是获得这个类的方法 Iter_based_runner的train方法。

然后用这个iter_runner对这个iter_loaders进行处理。。

 类似资料: