import time

import torch
import torch.nn as nn

import accelerate

if __name__ == "__main__":
    # Small frozen model used only to generate GPU work.
    model = nn.Linear(512, 512)
    for p in model.parameters():
        p.requires_grad = False

    accelerator = accelerate.Accelerator()
    model = accelerator.prepare_model(model)
    model.to(accelerator.device)

    data = torch.randn(1024, 512).to(accelerator.device)

    # Keep every process alive and occasionally run a forward pass so the GPU
    # never sits completely idle.
    while True:
        time.sleep(0.01)
        accelerator.wait_for_everyone()
        if torch.cuda.utilization() < 1.5:
            # GPU is (nearly) idle: run a throwaway forward pass.
            with torch.no_grad():
                model(data)
        else:
            # GPU is already busy: back off for a second.
            time.sleep(1)
        # print(f"rank:{accelerator.process_index}->usage:{torch.cuda.utilization()}")