Compare commits
10 Commits
3693640ca3
...
cf45afe325
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf45afe325 | ||
|
|
1d1b4d2913 | ||
|
|
ae7df43ecc | ||
|
|
2c16b8f423 | ||
|
|
8038c16bee | ||
|
|
598f7b40a2 | ||
|
|
3093a65151 | ||
|
|
910b764275 | ||
|
|
26c8e54edb | ||
|
|
99d92c94e7 |
@@ -13,6 +13,7 @@ We decouple diffusion transformer into encoder-decoder design, and surprisingly
|
|||||||
* We achieves **1.26 FID** on ImageNet256x256 Benchmark with DDT-XL/2(22en6de).
|
* We achieves **1.26 FID** on ImageNet256x256 Benchmark with DDT-XL/2(22en6de).
|
||||||
* We achieves **1.28 FID** on ImageNet512x512 Benchmark with DDT-XL/2(22en6de).
|
* We achieves **1.28 FID** on ImageNet512x512 Benchmark with DDT-XL/2(22en6de).
|
||||||
* As a byproduct, our DDT can reuse encoder among adjacent steps to accelerate inference.
|
* As a byproduct, our DDT can reuse encoder among adjacent steps to accelerate inference.
|
||||||
|
|
||||||
## Visualizations
|
## Visualizations
|
||||||

|

|
||||||
## Checkpoints
|
## Checkpoints
|
||||||
|
|||||||
116
configs/ddt_butterflies_b2_256.yaml
Normal file
116
configs/ddt_butterflies_b2_256.yaml
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
seed_everything: true
|
||||||
|
tags:
|
||||||
|
exp: &exp ddt_butterflies_b2_256
|
||||||
|
torch_hub_dir: null
|
||||||
|
huggingface_cache_dir: null
|
||||||
|
trainer:
|
||||||
|
default_root_dir: workdirs
|
||||||
|
accelerator: auto
|
||||||
|
strategy: auto
|
||||||
|
devices: auto
|
||||||
|
num_nodes: 1
|
||||||
|
precision: bf16-mixed
|
||||||
|
logger:
|
||||||
|
class_path: lightning.pytorch.loggers.WandbLogger
|
||||||
|
init_args:
|
||||||
|
project: ddt_butterflies
|
||||||
|
name: *exp
|
||||||
|
save_dir: workdirs
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
max_steps: 200000
|
||||||
|
val_check_interval: 2000
|
||||||
|
check_val_every_n_epoch: null
|
||||||
|
log_every_n_steps: 50
|
||||||
|
deterministic: null
|
||||||
|
inference_mode: true
|
||||||
|
use_distributed_sampler: false
|
||||||
|
callbacks:
|
||||||
|
- class_path: src.callbacks.model_checkpoint.CheckpointHook
|
||||||
|
init_args:
|
||||||
|
every_n_train_steps: 10000
|
||||||
|
save_top_k: -1
|
||||||
|
save_last: true
|
||||||
|
- class_path: src.callbacks.save_images.SaveImagesHook
|
||||||
|
init_args:
|
||||||
|
save_dir: val
|
||||||
|
max_save_num: 64
|
||||||
|
model:
|
||||||
|
vae:
|
||||||
|
class_path: src.models.vae.LatentVAE
|
||||||
|
init_args:
|
||||||
|
precompute: false
|
||||||
|
weight_path: stabilityai/sd-vae-ft-ema
|
||||||
|
denoiser:
|
||||||
|
class_path: src.models.denoiser.decoupled_improved_dit.DDT
|
||||||
|
init_args:
|
||||||
|
in_channels: 4
|
||||||
|
patch_size: 2
|
||||||
|
num_groups: 12
|
||||||
|
hidden_size: &hidden_dim 768
|
||||||
|
num_blocks: 12
|
||||||
|
num_encoder_blocks: 8
|
||||||
|
num_classes: 1
|
||||||
|
conditioner:
|
||||||
|
class_path: src.models.conditioner.LabelConditioner
|
||||||
|
init_args:
|
||||||
|
null_class: 1
|
||||||
|
diffusion_trainer:
|
||||||
|
class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
|
||||||
|
init_args:
|
||||||
|
lognorm_t: true
|
||||||
|
encoder_weight_path: dinov2_vitb14
|
||||||
|
align_layer: 4
|
||||||
|
proj_denoiser_dim: *hidden_dim
|
||||||
|
proj_hidden_dim: *hidden_dim
|
||||||
|
proj_encoder_dim: 768
|
||||||
|
scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||||
|
diffusion_sampler:
|
||||||
|
class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
|
||||||
|
init_args:
|
||||||
|
num_steps: 250
|
||||||
|
guidance: 1.0
|
||||||
|
timeshift: 1.0
|
||||||
|
state_refresh_rate: 1
|
||||||
|
guidance_interval_min: 0.3
|
||||||
|
guidance_interval_max: 1.0
|
||||||
|
scheduler: *scheduler
|
||||||
|
w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||||
|
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
|
||||||
|
last_step: 0.04
|
||||||
|
step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
|
||||||
|
ema_tracker:
|
||||||
|
class_path: src.callbacks.simple_ema.SimpleEMA
|
||||||
|
init_args:
|
||||||
|
decay: 0.9999
|
||||||
|
optimizer:
|
||||||
|
class_path: torch.optim.AdamW
|
||||||
|
init_args:
|
||||||
|
lr: 1e-3
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.95
|
||||||
|
weight_decay: 0.0
|
||||||
|
data:
|
||||||
|
train_dataset: hf_image
|
||||||
|
train_root: ./data/butterflies
|
||||||
|
test_nature_root: null
|
||||||
|
test_gen_root: null
|
||||||
|
train_image_size: 256
|
||||||
|
train_batch_size: 128
|
||||||
|
train_num_workers: 8
|
||||||
|
train_prefetch_factor: 2
|
||||||
|
train_hf_name: huggan/smithsonian_butterflies_subset
|
||||||
|
train_hf_split: train
|
||||||
|
train_hf_image_column: image
|
||||||
|
train_hf_label_column: null
|
||||||
|
train_hf_cache_dir: null
|
||||||
|
eval_max_num_instances: 256
|
||||||
|
pred_batch_size: 32
|
||||||
|
pred_num_workers: 4
|
||||||
|
pred_seeds: null
|
||||||
|
pred_selected_classes: null
|
||||||
|
num_classes: 1
|
||||||
|
latent_shape:
|
||||||
|
- 4
|
||||||
|
- 32
|
||||||
|
- 32
|
||||||
@@ -7,3 +7,4 @@ torchvision
|
|||||||
timm
|
timm
|
||||||
accelerate
|
accelerate
|
||||||
gradio
|
gradio
|
||||||
|
datasets
|
||||||
|
|||||||
@@ -41,23 +41,25 @@ class AdamLMSampler(BaseSampler):
|
|||||||
self,
|
self,
|
||||||
order: int = 2,
|
order: int = 2,
|
||||||
timeshift: float = 1.0,
|
timeshift: float = 1.0,
|
||||||
|
guidance_interval_min: float = 0.0,
|
||||||
|
guidance_interval_max: float = 1.0,
|
||||||
lms_transform_fn: Callable = nop,
|
lms_transform_fn: Callable = nop,
|
||||||
w_scheduler: BaseScheduler = None,
|
|
||||||
step_fn: Callable = ode_step_fn,
|
step_fn: Callable = ode_step_fn,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.step_fn = step_fn
|
self.step_fn = step_fn
|
||||||
self.w_scheduler = w_scheduler
|
|
||||||
|
|
||||||
assert self.scheduler is not None
|
assert self.scheduler is not None
|
||||||
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
|
assert self.step_fn in [ode_step_fn, ]
|
||||||
self.order = order
|
self.order = order
|
||||||
self.lms_transform_fn = lms_transform_fn
|
self.lms_transform_fn = lms_transform_fn
|
||||||
|
|
||||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||||
|
self.guidance_interval_min = guidance_interval_min
|
||||||
|
self.guidance_interval_max = guidance_interval_max
|
||||||
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
||||||
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
|
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
|
||||||
self._reparameterize_coeffs()
|
self._reparameterize_coeffs()
|
||||||
@@ -93,7 +95,11 @@ class AdamLMSampler(BaseSampler):
|
|||||||
cfg_x = torch.cat([x, x], dim=0)
|
cfg_x = torch.cat([x, x], dim=0)
|
||||||
cfg_t = t_cur.repeat(2)
|
cfg_t = t_cur.repeat(2)
|
||||||
out = net(cfg_x, cfg_t, cfg_condition)
|
out = net(cfg_x, cfg_t, cfg_condition)
|
||||||
out = self.guidance_fn(out, self.guidance)
|
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
|
||||||
|
guidance = self.guidance
|
||||||
|
out = self.guidance_fn(out, guidance)
|
||||||
|
else:
|
||||||
|
out = self.guidance_fn(out, 1.0)
|
||||||
pred_trajectory.append(out)
|
pred_trajectory.append(out)
|
||||||
out = torch.zeros_like(out)
|
out = torch.zeros_like(out)
|
||||||
order = len(self.solver_coeffs[i])
|
order = len(self.solver_coeffs[i])
|
||||||
|
|||||||
104
src/diffusion/flow_matching/training_disperse.py
Normal file
104
src/diffusion/flow_matching/training_disperse.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import torch
|
||||||
|
import copy
|
||||||
|
import timm
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
from src.utils.no_grad import no_grad
|
||||||
|
from typing import Callable, Iterator, Tuple
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from torchvision.transforms import Normalize
|
||||||
|
from src.diffusion.base.training import *
|
||||||
|
from src.diffusion.base.scheduling import BaseScheduler
|
||||||
|
|
||||||
|
def inverse_sigma(alpha, sigma):
|
||||||
|
return 1/sigma**2
|
||||||
|
def snr(alpha, sigma):
|
||||||
|
return alpha/sigma
|
||||||
|
def minsnr(alpha, sigma, threshold=5):
|
||||||
|
return torch.clip(alpha/sigma, min=threshold)
|
||||||
|
def maxsnr(alpha, sigma, threshold=5):
|
||||||
|
return torch.clip(alpha/sigma, max=threshold)
|
||||||
|
def constant(alpha, sigma):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def time_shift_fn(t, timeshift=1.0):
|
||||||
|
return t/(t+(1-t)*timeshift)
|
||||||
|
|
||||||
|
|
||||||
|
class DisperseTrainer(BaseTrainer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scheduler: BaseScheduler,
|
||||||
|
loss_weight_fn:Callable=constant,
|
||||||
|
feat_loss_weight: float=0.5,
|
||||||
|
lognorm_t=False,
|
||||||
|
timeshift=1.0,
|
||||||
|
align_layer=8,
|
||||||
|
temperature=1.0,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lognorm_t = lognorm_t
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.timeshift = timeshift
|
||||||
|
self.loss_weight_fn = loss_weight_fn
|
||||||
|
self.feat_loss_weight = feat_loss_weight
|
||||||
|
self.align_layer = align_layer
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||||
|
batch_size, c, height, width = x.shape
|
||||||
|
if self.lognorm_t:
|
||||||
|
base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid()
|
||||||
|
else:
|
||||||
|
base_t = torch.rand((batch_size), device=x.device, dtype=torch.float32)
|
||||||
|
t = time_shift_fn(base_t, self.timeshift).to(x.dtype)
|
||||||
|
noise = torch.randn_like(x)
|
||||||
|
alpha = self.scheduler.alpha(t)
|
||||||
|
dalpha = self.scheduler.dalpha(t)
|
||||||
|
sigma = self.scheduler.sigma(t)
|
||||||
|
dsigma = self.scheduler.dsigma(t)
|
||||||
|
|
||||||
|
x_t = alpha * x + noise * sigma
|
||||||
|
v_t = dalpha * x + dsigma * noise
|
||||||
|
|
||||||
|
src_feature = []
|
||||||
|
def forward_hook(net, input, output):
|
||||||
|
feature = output
|
||||||
|
if isinstance(feature, tuple):
|
||||||
|
feature = feature[0] # mmdit
|
||||||
|
src_feature.append(feature)
|
||||||
|
|
||||||
|
if getattr(net, "encoder", None) is not None:
|
||||||
|
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||||
|
else:
|
||||||
|
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||||
|
|
||||||
|
out = net(x_t, t, y)
|
||||||
|
handle.remove()
|
||||||
|
disperse_loss = 0.0
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
for sf in src_feature:
|
||||||
|
gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)]
|
||||||
|
torch.distributed.all_gather(gathered_sf, sf)
|
||||||
|
gathered_sf = torch.cat(gathered_sf, dim=0)
|
||||||
|
sf = gathered_sf.view(batch_size * world_size, -1)
|
||||||
|
sf = sf.view(batch_size * world_size, -1)
|
||||||
|
# normalize sf
|
||||||
|
sf = sf / torch.norm(sf, dim=1, keepdim=True)
|
||||||
|
distance = torch.nn.functional.pdist(sf, p=2) ** 2
|
||||||
|
sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5
|
||||||
|
disperse_loss += sf_disperse_distance.mean().log()
|
||||||
|
|
||||||
|
|
||||||
|
weight = self.loss_weight_fn(alpha, sigma)
|
||||||
|
fm_loss = weight*(out - v_t)**2
|
||||||
|
|
||||||
|
out = dict(
|
||||||
|
fm_loss=fm_loss.mean(),
|
||||||
|
disperse_loss=disperse_loss.mean(),
|
||||||
|
loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(),
|
||||||
|
)
|
||||||
|
return out
|
||||||
@@ -41,22 +41,24 @@ class AdamLMSampler(BaseSampler):
|
|||||||
self,
|
self,
|
||||||
order: int = 2,
|
order: int = 2,
|
||||||
timeshift: float = 1.0,
|
timeshift: float = 1.0,
|
||||||
|
guidance_interval_min: float = 0.0,
|
||||||
|
guidance_interval_max: float = 1.0,
|
||||||
state_refresh_rate: int = 1,
|
state_refresh_rate: int = 1,
|
||||||
lms_transform_fn: Callable = nop,
|
lms_transform_fn: Callable = nop,
|
||||||
w_scheduler: BaseScheduler = None,
|
|
||||||
step_fn: Callable = ode_step_fn,
|
step_fn: Callable = ode_step_fn,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.step_fn = step_fn
|
self.step_fn = step_fn
|
||||||
self.w_scheduler = w_scheduler
|
|
||||||
self.state_refresh_rate = state_refresh_rate
|
self.state_refresh_rate = state_refresh_rate
|
||||||
|
|
||||||
assert self.scheduler is not None
|
assert self.scheduler is not None
|
||||||
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
|
assert self.step_fn in [ode_step_fn, ]
|
||||||
self.order = order
|
self.order = order
|
||||||
self.lms_transform_fn = lms_transform_fn
|
self.lms_transform_fn = lms_transform_fn
|
||||||
|
self.guidance_interval_min = guidance_interval_min
|
||||||
|
self.guidance_interval_max = guidance_interval_max
|
||||||
|
|
||||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||||
@@ -98,7 +100,11 @@ class AdamLMSampler(BaseSampler):
|
|||||||
if i % self.state_refresh_rate == 0:
|
if i % self.state_refresh_rate == 0:
|
||||||
state = None
|
state = None
|
||||||
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
||||||
out = self.guidance_fn(out, self.guidance)
|
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
|
||||||
|
guidance = self.guidance
|
||||||
|
out = self.guidance_fn(out, guidance)
|
||||||
|
else:
|
||||||
|
out = self.guidance_fn(out, 1.0)
|
||||||
pred_trajectory.append(out)
|
pred_trajectory.append(out)
|
||||||
out = torch.zeros_like(out)
|
out = torch.zeros_like(out)
|
||||||
order = len(self.solver_coeffs[i])
|
order = len(self.solver_coeffs[i])
|
||||||
|
|||||||
@@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule):
|
|||||||
var_transform_engine: VARTransformEngine = None,
|
var_transform_engine: VARTransformEngine = None,
|
||||||
train_prefetch_factor=2,
|
train_prefetch_factor=2,
|
||||||
train_dataset: str = None,
|
train_dataset: str = None,
|
||||||
|
train_hf_name: str = None,
|
||||||
|
train_hf_split: str = "train",
|
||||||
|
train_hf_image_column: str = "image",
|
||||||
|
train_hf_label_column: str = None,
|
||||||
|
train_hf_cache_dir: str = None,
|
||||||
eval_batch_size=32,
|
eval_batch_size=32,
|
||||||
eval_num_workers=4,
|
eval_num_workers=4,
|
||||||
eval_max_num_instances=50000,
|
eval_max_num_instances=50000,
|
||||||
@@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule):
|
|||||||
self.train_root = train_root
|
self.train_root = train_root
|
||||||
self.train_image_size = train_image_size
|
self.train_image_size = train_image_size
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
|
self.train_hf_name = train_hf_name
|
||||||
|
self.train_hf_split = train_hf_split
|
||||||
|
self.train_hf_image_column = train_hf_image_column
|
||||||
|
self.train_hf_label_column = train_hf_label_column
|
||||||
|
self.train_hf_cache_dir = train_hf_cache_dir
|
||||||
# stupid data_convert override, just to make nebular happy
|
# stupid data_convert override, just to make nebular happy
|
||||||
self.train_batch_size = train_batch_size
|
self.train_batch_size = train_batch_size
|
||||||
self.train_num_workers = train_num_workers
|
self.train_num_workers = train_num_workers
|
||||||
@@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule):
|
|||||||
self.train_dataset = PixImageNet512(
|
self.train_dataset = PixImageNet512(
|
||||||
root=self.train_root,
|
root=self.train_root,
|
||||||
)
|
)
|
||||||
|
elif self.train_dataset == "hf_image":
|
||||||
|
from src.data.dataset.hf_image import HuggingFaceImageDataset
|
||||||
|
if self.train_hf_name is None:
|
||||||
|
raise ValueError("train_hf_name must be set when train_dataset=hf_image")
|
||||||
|
self.train_dataset = HuggingFaceImageDataset(
|
||||||
|
name=self.train_hf_name,
|
||||||
|
split=self.train_hf_split,
|
||||||
|
image_column=self.train_hf_image_column,
|
||||||
|
label_column=self.train_hf_label_column,
|
||||||
|
resolution=self.train_image_size,
|
||||||
|
cache_dir=self.train_hf_cache_dir,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("no such dataset")
|
raise NotImplementedError("no such dataset")
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule):
|
|||||||
ema_tracker: Optional[EMACallable] = None,
|
ema_tracker: Optional[EMACallable] = None,
|
||||||
optimizer: OptimizerCallable = None,
|
optimizer: OptimizerCallable = None,
|
||||||
lr_scheduler: LRSchedulerCallable = None,
|
lr_scheduler: LRSchedulerCallable = None,
|
||||||
|
compile: bool = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
@@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule):
|
|||||||
# self.model_loader = ModelLoader()
|
# self.model_loader = ModelLoader()
|
||||||
|
|
||||||
self._strict_loading = False
|
self._strict_loading = False
|
||||||
|
self._compile = compile
|
||||||
|
|
||||||
def configure_model(self) -> None:
|
def configure_model(self) -> None:
|
||||||
self.trainer.strategy.barrier()
|
self.trainer.strategy.barrier()
|
||||||
@@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule):
|
|||||||
no_grad(self.diffusion_sampler)
|
no_grad(self.diffusion_sampler)
|
||||||
no_grad(self.ema_denoiser)
|
no_grad(self.ema_denoiser)
|
||||||
|
|
||||||
|
# add compile to speed up
|
||||||
|
if self._compile:
|
||||||
|
self.denoiser = torch.compile(self.denoiser)
|
||||||
|
self.ema_denoiser = torch.compile(self.ema_denoiser)
|
||||||
|
|
||||||
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
||||||
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
|
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
|
||||||
return [ema_tracker]
|
return [ema_tracker]
|
||||||
|
|||||||
@@ -251,11 +251,11 @@ class DiT(nn.Module):
|
|||||||
self.initialize_weights()
|
self.initialize_weights()
|
||||||
self.precompute_pos = dict()
|
self.precompute_pos = dict()
|
||||||
|
|
||||||
def fetch_pos(self, height, width, device, dtype):
|
def fetch_pos(self, height, width, device):
|
||||||
if (height, width) in self.precompute_pos:
|
if (height, width) in self.precompute_pos:
|
||||||
return self.precompute_pos[(height, width)].to(device, dtype)
|
return self.precompute_pos[(height, width)].to(device)
|
||||||
else:
|
else:
|
||||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
|
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||||
self.precompute_pos[(height, width)] = pos
|
self.precompute_pos[(height, width)] = pos
|
||||||
return pos
|
return pos
|
||||||
|
|
||||||
@@ -289,7 +289,7 @@ class DiT(nn.Module):
|
|||||||
B, _, H, W = x.shape
|
B, _, H, W = x.shape
|
||||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||||
x = self.x_embedder(x)
|
x = self.x_embedder(x)
|
||||||
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
|
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
|
||||||
B, L, C = x.shape
|
B, L, C = x.shape
|
||||||
t = self.t_embedder(t.view(-1)).view(B, -1, C)
|
t = self.t_embedder(t.view(-1)).view(B, -1, C)
|
||||||
y = self.y_embedder(y).view(B, 1, C)
|
y = self.y_embedder(y).view(B, 1, C)
|
||||||
|
|||||||
BIN
src/utils/__pycache__/patch_bugs.cpython-312.pyc
Normal file
BIN
src/utils/__pycache__/patch_bugs.cpython-312.pyc
Normal file
Binary file not shown.
Reference in New Issue
Block a user