# File: DDT/tools/mm.py
# Commit 06499f1caa by wangshuai6 ("submit code"), 2025-04-09 11:01:16 +08:00
# Original size: 23 lines, 681 B (Python)

import torch
import time
import torch.nn as nn
import accelerate
def main(
    *,
    idle_threshold: float = 1.5,
    poll_interval: float = 0.01,
    backoff: float = 1.0,
    features: int = 512,
    batch_size: int = 1024,
    verbose: bool = False,
) -> None:
    """Keep each process's GPU busy with a small matmul whenever it looks idle.

    Every process started through ``accelerate`` builds a frozen
    ``nn.Linear(features, features)`` and a random ``(batch_size, features)``
    input on its own device, then loops forever:

    1. sleep ``poll_interval`` seconds;
    2. wait until every rank has reached the same point
       (``accelerator.wait_for_everyone()``);
    3. if ``torch.cuda.utilization`` for this process's device is below
       ``idle_threshold``, run one no-grad forward pass; otherwise sleep
       ``backoff`` seconds.

    The function never returns; stop it by interrupting/killing the launcher.

    Args:
        idle_threshold: Utilization value (percent, as returned by
            ``torch.cuda.utilization``) below which the device counts as idle.
            The default 1.5 is the original hard-coded threshold.
        poll_interval: Seconds to sleep before each poll.
        backoff: Seconds to sleep when the device is considered busy.
        features: Input/output width of the dummy linear layer.
        batch_size: Number of rows in the dummy input batch.
        verbose: If True, print the rank and measured utilization each poll.
    """
    model = nn.Linear(features, features)
    # The layer is only a load generator; nothing is ever trained.
    model.requires_grad_(False)
    accelerator = accelerate.Accelerator()
    model = accelerator.prepare_model(model)
    # prepare_model normally moves the model already; the explicit move keeps
    # placement independent of the Accelerator's device_placement setting.
    model.to(accelerator.device)
    # Values are irrelevant, so allocate directly on the target device.
    data = torch.randn(batch_size, features, device=accelerator.device)

    while True:
        time.sleep(poll_interval)
        # Barrier: no rank starts a new poll until every rank has finished its
        # previous iteration (including any back-off sleep).
        accelerator.wait_for_everyone()
        # Query this process's device explicitly instead of relying on
        # torch.cuda.current_device() happening to match it.
        usage = torch.cuda.utilization(accelerator.device)
        if verbose:
            print(f"rank:{accelerator.process_index}->usage:{usage}")
        if usage < idle_threshold:
            with torch.no_grad():
                model(data)
        else:
            time.sleep(backoff)


if __name__ == "__main__":
    main()