submit code

This commit is contained in:
wangshuai6
2025-04-09 11:01:16 +08:00
parent 4fbcf9bd87
commit 06499f1caa
145 changed files with 14400 additions and 0 deletions

26
src/models/conditioner.py Normal file
View File

@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
class BaseConditioner(nn.Module):
def __init__(self):
super(BaseConditioner, self).__init__()
def _impl_condition(self, y):
...
def _impl_uncondition(self, y):
...
def __call__(self, y):
condition = self._impl_condition(y)
uncondition = self._impl_uncondition(y)
return condition, uncondition
class LabelConditioner(BaseConditioner):
def __init__(self, null_class):
super().__init__()
self.null_condition = null_class
def _impl_condition(self, y):
return torch.tensor(y).long().cuda()
def _impl_uncondition(self, y):
return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda()