1234567891011121314151617 |
- import torch
- import sys
def extract_submodel_state_dict(state_dict, submodel="cond_stage_model"):
    """Return the entries of *state_dict* that belong to *submodel*.

    Only keys of the form ``"<submodel>.<rest>"`` are kept, and the
    ``"<submodel>."`` prefix is removed so the result is keyed by ``<rest>``.

    Args:
        state_dict: Mapping from parameter name to value.
        submodel: Dotted name of the submodule to extract. It may itself
            contain dots (for example ``"model.diffusion_model"``).

    Returns:
        A new dict holding the matching entries with the prefix stripped.
        The dict is empty when no key matches.
    """
    # Match on "<submodel>." rather than "<submodel>" so that sibling
    # modules whose names merely share the prefix are not swept in, and so
    # that the whole (possibly dotted) prefix is stripped, not just the
    # first dotted component.
    prefix = submodel + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    # Usage: <script> INPATH OUTPATH [SUBMODEL]
    if len(sys.argv) < 3:
        sys.exit("usage: {} INPATH OUTPATH [SUBMODEL]".format(sys.argv[0]))

    inpath = sys.argv[1]
    outpath = sys.argv[2]
    submodel = "cond_stage_model"
    if len(sys.argv) > 3:
        submodel = sys.argv[3]
    print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    sd = torch.load(inpath, map_location="cpu")

    # Honor the requested submodel instead of always extracting
    # "cond_stage_model".
    extracted = extract_submodel_state_dict(sd["state_dict"], submodel)
    if not extracted:
        # Writing an empty checkpoint would hide a mistyped submodel name.
        sys.exit("No parameters found for submodel {!r} in {}.".format(
            submodel, inpath))

    new_sd = {"state_dict": extracted}
    torch.save(new_sd, outpath)
|