|
|
|
@ -1559,11 +1559,11 @@ class MultiModalVisualAgent:
|
|
|
|
|
def clear_memory(self):
|
|
|
|
|
self.memory.clear()
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
if not os.path.exists("checkpoints"):
|
|
|
|
|
os.mkdir("checkpoints")
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
|
parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
|
|
|
|
|
agent = MultiModalVisualAgent(load_dict=load_dict)
|
|
|
|
|
# if __name__ == '__main__':
|
|
|
|
|
# if not os.path.exists("checkpoints"):
|
|
|
|
|
# os.mkdir("checkpoints")
|
|
|
|
|
# parser = argparse.ArgumentParser()
|
|
|
|
|
# parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
|
|
|
|
|
# args = parser.parse_args()
|
|
|
|
|
# load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
|
|
|
|
|
# agent = MultiModalVisualAgent(load_dict=load_dict)
|