# --- Working-directory bootstrap -------------------------------------------
# The 'ai.*' imports below resolve relative to the repo root ('zenith'),
# so force the CWD there before importing anything project-local.
import sys, os
if not os.getcwd().split('/')[-1] == 'zenith':
    # NOTE(review): hard-coded, machine-specific absolute path — this will
    # fail on any other host; consider deriving the root from __file__.
    os.chdir('/var/home/ncmir-lab/madany/zenith')
# When run as a script from inside ai/, step up to the repo root instead.
if __name__ == '__main__' and os.getcwd().split('/')[-1] == 'ai': os.chdir('..')
print(os.getcwd())
sys.path.append(os.getcwd())
import torch
import torch.nn as nn
import functools
import time
from ai.blocks import UnetSkipConnectionBlock#, ResnetBlock
from ai.nets.Resnet import Resnet
from ai.nets.ResUnetPlus import ResUnetPlusPlus
from ai.nets.GroupUnet import GroupUnet
from ai.nets.RAUNet import RAUNet
from ai.nets.CDUnet import CDUnet
from ai.nets.ResUnet2 import AniResUnet
from ai.nets.ResBUnet import AniResBUnet
import copy
# The vendored 3D U-Net repo is not an installed package; make it importable
# for the deferred `from pytorch3dunet...` import inside Unet() below.
sys.path.append(os.getcwd()+'/repos/unet3d')
#%%
def Backends():
    """Return the registry mapping backend-name strings to network builders.

    Keys are the names accepted elsewhere in the project; values are the
    callables (classes / factory functions) that construct the network.
    """
    # ResUnetPlusPlus and AttnGate entries are intentionally disabled.
    return {
        'Unet': Unet,
        'Resnet': Resnet,
        'UBnet': GroupUnet,
        'ResUnet': AniResUnet,
        'RBUnet': AniResBUnet,
    }
def TestNetwork(model, input_nc = 1, output_nc = 8,
                tensor_size = (5,256,256), n_tests = 8, batch = 1,
                cuda = True):
    """Instantiate a network backend and run a timed forward-pass smoke test.

    Parameters
    ----------
    model : callable
        Factory/class accepting keyword args (input_nc, output_nc,
        tensor_size) and returning an nn.Module.
    input_nc, output_nc : int
        Channel counts handed to the factory; output_nc is also validated
        against the produced tensor.
    tensor_size : tuple
        Spatial size (D, H, W) of the dummy input volume.
    n_tests : int
        Number of forward passes to time; must be >= 1.
    batch : int
        Batch dimension of the dummy input.
    cuda : bool
        Run on GPU (wrapped in DataParallel) when True, else on CPU.

    Raises
    ------
    ValueError
        If n_tests < 1, or if the output's spatial dims do not match the
        input's, or its channel dim does not equal output_nc.
    """
    # Guard: the original referenced `y` after the loop and raised a
    # confusing NameError when n_tests == 0 — fail loudly instead.
    if n_tests < 1:
        raise ValueError('n_tests must be >= 1, got %d' % n_tests)
    model_nm = model  # keep the factory around for the report line
    model = model(input_nc = input_nc, output_nc = output_nc, tensor_size = tensor_size)
    # DataParallel transparently runs the bare module when no GPUs are visible.
    model = torch.nn.DataParallel(model)
    ts = (batch, input_nc, *tensor_size)
    x = torch.ones(ts)
    if cuda:
        # x is already a Tensor — the old torch.Tensor(x) made a needless copy.
        x = x.cuda()
        model = model.to('cuda')
    else:
        x = x.cpu()
        model = model.to('cpu')
    start_time = time.time()
    # no_grad: this is a shape/timing smoke test, autograd bookkeeping is
    # wasted memory and time.
    with torch.no_grad():
        for _ in range(n_tests):
            y = model(x)
    # Same contract as before, but compare shapes directly instead of via
    # string reprs: spatial dims must match the input, channels == output_nc.
    if tuple(y.shape[2:]) != tuple(x.shape[2:]) or int(y.shape[1]) != output_nc:
        raise ValueError('Expected output tensor of shape %s but got %s' % (str((1,output_nc,*tensor_size)), str(y.shape)))
    print('Model %s passed test. Input Channels: %g | Output Channels: %g | Tensor size: %s | Batches: %g | Time: %s' % (str(model_nm),
          input_nc, output_nc, str(tensor_size), batch,
          str(time.time()-start_time)))
    del(x, y, model)
#%%
def Unet(input_nc, output_nc, features = 64, tensor_size = None):
    """Build a 3-level UNet3D (pytorch3dunet) with a modified head and tail.

    The stock UNet3D gets two modifications:
      * an extra block inserted after the first encoder: replication-pad the
        depth axis by 1 on each side, then a kernel-3 transposed conv,
      * a replacement final_conv stack ending in a 1x1x1 projection to
        `output_nc` channels, with the default final activation removed
        (raw logits out).

    Parameters
    ----------
    input_nc : int
        Number of input channels (also selects num_groups: 1 if 1, else 2).
    output_nc : int
        Number of output channels of the final 1x1x1 projection.
    features : int
        Base feature-map count passed to UNet3D as f_maps.
    tensor_size : tuple or None
        Accepted for signature compatibility with the other backends in
        this file; not used by this builder.
    """
    # NOTE(review): the original computed `zp = 2 if tensor_size[0] == 3
    # else 4` here but never used it — and it crashed with TypeError under
    # the default tensor_size=None. The dead computation has been removed.
    # Deferred import: pytorch3dunet is vendored under repos/unet3d and is
    # only importable after the sys.path bootstrap at the top of this file.
    from pytorch3dunet.unet3d.model import UNet3D
    unet = UNet3D(input_nc, output_nc, f_maps = features, num_levels = 3,
                  num_groups = 1 if input_nc == 1 else 2,
                  conv_kernel_size = 3)
    # Depth-pad by (1,1), then ConvTranspose3d(k=3, s=1, p=0) — which grows
    # each spatial dim by 2 — inserted as a second "encoder" stage.
    beg = [nn.ReplicationPad3d((0, 0, 0, 0, 1, 1))]
    beg += [nn.BatchNorm3d(features),
            nn.ConvTranspose3d(features, features, kernel_size=3, stride=(1, 1, 1),
                               padding=0, output_padding=0),
            nn.ReLU(True)]
    unet.encoders.insert(1, nn.Sequential(*beg))
    # Replacement head: (1,3,3) transposed conv, then a kernel-3 conv (no
    # padding), then the 1x1x1 projection to output_nc channels.
    fin = [nn.BatchNorm3d(features),
           nn.ConvTranspose3d(features, features, kernel_size=(1, 3, 3), stride=1,
                              padding=1, output_padding=0),
           nn.ReLU(True)]
    fin += [nn.BatchNorm3d(features),
            nn.Conv3d(features, features, kernel_size=3, stride=1, padding=0),
            nn.ReLU(True)]
    fin += [nn.Conv3d(features, output_nc, kernel_size=1, stride=1)]
    unet.final_conv = nn.Sequential(*fin)
    # Remove the library's default final activation so the net emits logits.
    unet.final_activation = None
    return unet