
SUBMARINE-52. [SUBMARINE-14] Generate Service spec + launch script for single-node PyTorch learning job. Contributed by Szilard Nemeth.

Zhankun Tang 6 years ago
parent commit
36267b6f7c
100 changed files with 4286 additions and 973 deletions
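
This change introduces a "framework" CLI option (CliConstants.FRAMEWORK in the diff below) that selects between TensorFlow, the default when the option is omitted, and PyTorch. A single-node PyTorch submission then looks roughly like the following sketch; the jar name/version, resource values, and job name are illustrative, while the option names and the pytorch-latest-gpu image come from this commit:

    yarn jar hadoop-submarine-core-<version>.jar job run \
      --name pytorch-job-001 \
      --framework pytorch \
      --docker_image pytorch-latest-gpu:0.0.1 \
      --num_workers 1 \
      --worker_resources memory=4G,vcores=2,gpu=1 \
      --worker_launch_cmd "python /test/cifar10_tutorial.py"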
  1. 77 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/base/ubuntu-16.04/Dockerfile.gpu.pytorch_latest
  2. 30 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/build-all.sh
  3. 354 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/with-cifar10-models/cifar10_tutorial.py
  4. 21 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.pytorch_latest
  5. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/base/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1
  6. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/base/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1
  7. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/build-all.sh
  8. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1
  9. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1
  10. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/README.md
  11. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10.py
  12. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_main.py
  13. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_model.py
  14. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_utils.py
  15. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/generate_cifar10_tfrecords.py
  16. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/model_base.py
  17. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/Dockerfile.gpu
  18. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/run_container.sh
  19. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/shiro.ini
  20. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/zeppelin-site.xml
  21. 1 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/Cli.java
  22. 2 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/CliConstants.java
  23. 1 1
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/CliUtils.java
  24. 24 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/Command.java
  25. 6 6
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/ShowJobCli.java
  26. 24 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/ConfigType.java
  27. 134 9
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/ParametersHolder.java
  28. 120 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/PyTorchRunJobParameters.java
  29. 146 197
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/RunJobParameters.java
  30. 215 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/TensorFlowRunJobParameters.java
  31. 20 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/package-info.java
  32. 9 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/yaml/Spec.java
  33. 59 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/Framework.java
  34. 81 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/RoleParameters.java
  35. 123 106
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/RunJobCli.java
  36. 19 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/package-info.java
  37. 54 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/PyTorchRole.java
  38. 25 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/Role.java
  39. 58 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/Runtime.java
  40. 11 2
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/TensorFlowRole.java
  41. 5 5
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/common/JobSubmitter.java
  42. 7 0
      hadoop-submarine/hadoop-submarine-core/src/site/markdown/QuickStart.md
  43. 0 226
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java
  44. 6 5
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/YamlConfigTestUtils.java
  45. 129 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingCommon.java
  46. 252 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingCommonYaml.java
  47. 192 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingParameterized.java
  48. 209 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/pytorch/TestRunJobCliParsingPyTorch.java
  49. 225 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/pytorch/TestRunJobCliParsingPyTorchYaml.java
  50. 170 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlow.java
  51. 45 159
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYaml.java
  52. 8 9
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYamlStandalone.java
  53. 63 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/empty-framework.yaml
  54. 63 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/invalid-framework.yaml
  55. 0 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/missing-configs.yaml
  56. 0 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/missing-framework.yaml
  57. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/some-sections-missing.yaml
  58. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/test-false-values.yaml
  59. 0 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/wrong-indentation.yaml
  60. 0 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/wrong-property-name.yaml
  61. 51 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/envs-are-missing.yaml
  62. 56 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/invalid-config-ps-section.yaml
  63. 57 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/invalid-config-tensorboard-section.yaml
  64. 53 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/security-principal-is-missing.yaml
  65. 63 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/valid-config-with-overrides.yaml
  66. 54 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/valid-config.yaml
  67. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/envs-are-missing.yaml
  68. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/security-principal-is-missing.yaml
  69. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/tensorboard-dockerimage-is-missing.yaml
  70. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/valid-config-with-overrides.yaml
  71. 63 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/valid-config.yaml
  72. 17 5
      hadoop-submarine/hadoop-submarine-tony-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/tony/TonyJobSubmitter.java
  73. 2 2
      hadoop-submarine/hadoop-submarine-tony-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/tony/TonyUtils.java
  74. 4 0
      hadoop-submarine/hadoop-submarine-tony-runtime/src/site/markdown/QuickStart.md
  75. 19 7
      hadoop-submarine/hadoop-submarine-tony-runtime/src/test/java/TestTonyUtils.java
  76. 49 7
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java
  77. 167 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractServiceSpec.java
  78. 10 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java
  79. 20 5
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java
  80. 1 1
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java
  81. 71 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/WorkerComponentFactory.java
  82. 62 7
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java
  83. 4 7
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommand.java
  84. 5 42
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchCommandFactory.java
  85. 4 3
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java
  86. 61 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/PyTorchLaunchCommandFactory.java
  87. 70 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TensorFlowLaunchCommandFactory.java
  88. 68 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/PyTorchServiceSpec.java
  89. 87 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/command/PyTorchWorkerLaunchCommand.java
  90. 19 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/command/package-info.java
  91. 47 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/component/PyTorchWorkerComponent.java
  92. 20 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/component/package-info.java
  93. 20 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/package-info.java
  94. 5 5
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowCommons.java
  95. 22 125
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowServiceSpec.java
  96. 4 4
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorBoardLaunchCommand.java
  97. 11 7
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
  98. 5 4
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java
  99. 5 5
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java
  100. 16 12
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java

+ 77 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/base/ubuntu-16.04/Dockerfile.gpu.pytorch_latest

@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
+ARG PYTHON_VERSION=3.6
+RUN apt-get update && apt-get install -y --no-install-recommends \
+         build-essential \
+         cmake \
+         git \
+         curl \
+         vim \
+         ca-certificates \
+         libjpeg-dev \
+         libpng-dev \
+         wget &&\
+     rm -rf /var/lib/apt/lists/*
+
+
+RUN curl -o ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+     chmod +x ~/miniconda.sh && \
+     ~/miniconda.sh -b -p /opt/conda && \
+     rm ~/miniconda.sh && \
+     /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \
+     /opt/conda/bin/conda install -y -c pytorch magma-cuda100 && \
+     /opt/conda/bin/conda clean -ya
+ENV PATH /opt/conda/bin:$PATH
+RUN pip install ninja
+# This must be done before pip so that requirements.txt is available
+WORKDIR /opt/pytorch
+RUN git clone https://github.com/pytorch/pytorch.git
+WORKDIR pytorch
+RUN git submodule update --init
+RUN TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
+    CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
+    pip install -v .
+
+WORKDIR /opt/pytorch
+RUN git clone https://github.com/pytorch/vision.git && cd vision && pip install -v .
+
+WORKDIR /
+# Install Hadoop
+ENV HADOOP_VERSION="3.1.2"
+RUN wget https://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/hadoop-${HADOOP_VERSION}.tar.gz
+RUN tar zxf hadoop-${HADOOP_VERSION}.tar.gz
+RUN ln -s hadoop-${HADOOP_VERSION} hadoop-current
+RUN rm hadoop-${HADOOP_VERSION}.tar.gz
+
+ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
+RUN echo "Install java8" && \
+    apt-get update && \
+    apt-get install -y --no-install-recommends openjdk-8-jdk && \
+    apt-get clean && rm -rf /var/lib/apt/lists/*
+
+RUN echo "Install python related packages" && \
+    pip --no-cache-dir install Pillow h5py ipykernel jupyter matplotlib numpy pandas scipy sklearn && \
+    python -m ipykernel.kernelspec
+
+# Set the locale to fix bash warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)
+RUN apt-get update && apt-get install -y --no-install-recommends locales && \
+    apt-get clean && rm -rf /var/lib/apt/lists/*
+RUN locale-gen en_US.UTF-8
+
+
+WORKDIR /workspace
+RUN chmod -R a+w /workspace

+ 30 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/build-all.sh

@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+echo "Building base images"
+
+set -e
+
+cd base/ubuntu-16.04
+
+docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu-base:0.0.1
+
+echo "Finished building base images"
+
+cd ../../with-cifar10-models/ubuntu-16.04
+
+docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu:0.0.1
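
The cd commands are relative to the pytorch docker directory, so the script is meant to be run from there; for example:

    cd hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch
    bash build-all.sh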

+ 354 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/with-cifar10-models/cifar10_tutorial.py

@@ -0,0 +1,354 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#    http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# -*- coding: utf-8 -*-
+"""
+Training a Classifier
+=====================
+
+This is it. You have seen how to define neural networks, compute loss and make
+updates to the weights of the network.
+
+Now you might be thinking,
+
+What about data?
+----------------
+
+Generally, when you have to deal with image, text, audio or video data,
+you can use standard python packages that load data into a numpy array.
+Then you can convert this array into a ``torch.*Tensor``.
+
+-  For images, packages such as Pillow, OpenCV are useful
+-  For audio, packages such as scipy and librosa
+-  For text, either raw Python or Cython based loading, or NLTK and
+   SpaCy are useful
+
+Specifically for vision, we have created a package called
+``torchvision``, that has data loaders for common datasets such as
+Imagenet, CIFAR10, MNIST, etc. and data transformers for images, viz.,
+``torchvision.datasets`` and ``torch.utils.data.DataLoader``.
+
+This provides a huge convenience and avoids writing boilerplate code.
+
+For this tutorial, we will use the CIFAR10 dataset.
+It has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,
+‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of
+size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.
+
+.. figure:: /_static/img/cifar10.png
+   :alt: cifar10
+
+   cifar10
+
+
+Training an image classifier
+----------------------------
+
+We will do the following steps in order:
+
+1. Load and normalize the CIFAR10 training and test datasets using
+   ``torchvision``
+2. Define a Convolutional Neural Network
+3. Define a loss function
+4. Train the network on the training data
+5. Test the network on the test data
+
+1. Loading and normalizing CIFAR10
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Using ``torchvision``, it’s extremely easy to load CIFAR10.
+"""
+import torch
+import torchvision
+import torchvision.transforms as transforms
+
+########################################################################
+# The outputs of torchvision datasets are PILImage images in the range [0, 1].
+# We transform them to Tensors of normalized range [-1, 1].
+
+transform = transforms.Compose(
+  [transforms.ToTensor(),
+   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
+                                        download=True, transform=transform)
+trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
+                                          shuffle=True, num_workers=2)
+
+testset = torchvision.datasets.CIFAR10(root='./data', train=False,
+                                       download=True, transform=transform)
+testloader = torch.utils.data.DataLoader(testset, batch_size=4,
+                                         shuffle=False, num_workers=2)
+
+classes = ('plane', 'car', 'bird', 'cat',
+           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+
+########################################################################
+# Let us show some of the training images, for fun.
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+# functions to show an image
+
+
+def imshow(img):
+  img = img / 2 + 0.5  # unnormalize
+  npimg = img.numpy()
+  plt.imshow(np.transpose(npimg, (1, 2, 0)))
+  plt.show()
+
+
+# get some random training images
+dataiter = iter(trainloader)
+images, labels = dataiter.next()
+
+# show images
+imshow(torchvision.utils.make_grid(images))
+# print labels
+print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
+
+########################################################################
+# 2. Define a Convolutional Neural Network
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Copy the neural network from the Neural Networks section before and modify it to
+# take 3-channel images (instead of 1-channel images as it was defined).
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Net(nn.Module):
+  def __init__(self):
+    super(Net, self).__init__()
+    self.conv1 = nn.Conv2d(3, 6, 5)
+    self.pool = nn.MaxPool2d(2, 2)
+    self.conv2 = nn.Conv2d(6, 16, 5)
+    self.fc1 = nn.Linear(16 * 5 * 5, 120)
+    self.fc2 = nn.Linear(120, 84)
+    self.fc3 = nn.Linear(84, 10)
+
+  def forward(self, x):
+    x = self.pool(F.relu(self.conv1(x)))
+    x = self.pool(F.relu(self.conv2(x)))
+    x = x.view(-1, 16 * 5 * 5)
+    x = F.relu(self.fc1(x))
+    x = F.relu(self.fc2(x))
+    x = self.fc3(x)
+    return x
+
+
+net = Net()
+
+########################################################################
+# 3. Define a Loss function and optimizer
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Let's use a Classification Cross-Entropy loss and SGD with momentum.
+
+import torch.optim as optim
+
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+
+########################################################################
+# 4. Train the network
+# ^^^^^^^^^^^^^^^^^^^^
+#
+# This is when things start to get interesting.
+# We simply have to loop over our data iterator, and feed the inputs to the
+# network and optimize.
+
+for epoch in range(2):  # loop over the dataset multiple times
+
+  running_loss = 0.0
+  for i, data in enumerate(trainloader, 0):
+    # get the inputs
+    inputs, labels = data
+
+    # zero the parameter gradients
+    optimizer.zero_grad()
+
+    # forward + backward + optimize
+    outputs = net(inputs)
+    loss = criterion(outputs, labels)
+    loss.backward()
+    optimizer.step()
+
+    # print statistics
+    running_loss += loss.item()
+    if i % 2000 == 1999:  # print every 2000 mini-batches
+      print('[%d, %5d] loss: %.3f' %
+            (epoch + 1, i + 1, running_loss / 2000))
+      running_loss = 0.0
+
+print('Finished Training')
+
+########################################################################
+# 5. Test the network on the test data
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# We have trained the network for 2 passes over the training dataset.
+# But we need to check if the network has learnt anything at all.
+#
+# We will check this by predicting the class label that the neural network
+# outputs, and checking it against the ground-truth. If the prediction is
+# correct, we add the sample to the list of correct predictions.
+#
+# Okay, first step. Let us display an image from the test set to get familiar.
+
+dataiter = iter(testloader)
+images, labels = dataiter.next()
+
+# print images
+imshow(torchvision.utils.make_grid(images))
+print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
+
+########################################################################
+# Okay, now let us see what the neural network thinks these examples above are:
+
+outputs = net(images)
+
+########################################################################
+# The outputs are energies for the 10 classes.
+# The higher the energy for a class, the more the network
+# thinks that the image is of the particular class.
+# So, let's get the index of the highest energy:
+_, predicted = torch.max(outputs, 1)
+
+print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
+                              for j in range(4)))
+
+########################################################################
+# The results seem pretty good.
+#
+# Let us look at how the network performs on the whole dataset.
+
+correct = 0
+total = 0
+with torch.no_grad():
+  for data in testloader:
+    images, labels = data
+    outputs = net(images)
+    _, predicted = torch.max(outputs.data, 1)
+    total += labels.size(0)
+    correct += (predicted == labels).sum().item()
+
+print('Accuracy of the network on the 10000 test images: %d %%' % (
+        100 * correct / total))
+
+########################################################################
+# That looks waaay better than chance, which is 10% accuracy (randomly picking
+# a class out of 10 classes).
+# Seems like the network learnt something.
+#
+# Hmmm, what are the classes that performed well, and the classes that did
+# not perform well:
+
+class_correct = list(0. for i in range(10))
+class_total = list(0. for i in range(10))
+with torch.no_grad():
+  for data in testloader:
+    images, labels = data
+    outputs = net(images)
+    _, predicted = torch.max(outputs, 1)
+    c = (predicted == labels).squeeze()
+    for i in range(4):
+      label = labels[i]
+      class_correct[label] += c[i].item()
+      class_total[label] += 1
+
+for i in range(10):
+  print('Accuracy of %5s : %2d %%' % (
+    classes[i], 100 * class_correct[i] / class_total[i]))
+
+########################################################################
+# Okay, so what next?
+#
+# How do we run these neural networks on the GPU?
+#
+# Training on GPU
+# ----------------
+# Just like how you transfer a Tensor onto the GPU, you transfer the neural
+# net onto the GPU.
+#
+# Let's first define our device as the first visible cuda device if we have
+# CUDA available:
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Assuming that we are on a CUDA machine, this should print a CUDA device:
+
+print(device)
+
+########################################################################
+# The rest of this section assumes that ``device`` is a CUDA device.
+#
+# Then these methods will recursively go over all modules and convert their
+# parameters and buffers to CUDA tensors:
+#
+# .. code:: python
+#
+#     net.to(device)
+#
+#
+# Remember that you will have to send the inputs and targets at every step
+# to the GPU too:
+#
+# .. code:: python
+#
+#         inputs, labels = inputs.to(device), labels.to(device)
+#
+# Why don't I notice a MASSIVE speedup compared to CPU? Because your network
+# is realllly small.
+#
+# **Exercise:** Try increasing the width of your network (argument 2 of
+# the first ``nn.Conv2d``, and argument 1 of the second ``nn.Conv2d`` –
+# they need to be the same number), see what kind of speedup you get.
+#
+# **Goals achieved**:
+#
+# - Understanding PyTorch's Tensor library and neural networks at a high level.
+# - Train a small neural network to classify images
+#
+# Training on multiple GPUs
+# -------------------------
+# If you want to see even more MASSIVE speedup using all of your GPUs,
+# please check out :doc:`data_parallel_tutorial`.
+#
+# Where do I go next?
+# -------------------
+#
+# -  :doc:`Train neural nets to play video games </intermediate/reinforcement_q_learning>`
+# -  `Train a state-of-the-art ResNet network on imagenet`_
+# -  `Train a face generator using Generative Adversarial Networks`_
+# -  `Train a word-level language model using Recurrent LSTM networks`_
+# -  `More examples`_
+# -  `More tutorials`_
+# -  `Discuss PyTorch on the Forums`_
+# -  `Chat with other users on Slack`_
+#
+# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/master/imagenet
+# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/master/dcgan
+# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/master/word_language_model
+# .. _More examples: https://github.com/pytorch/examples
+# .. _More tutorials: https://github.com/pytorch/tutorials
+# .. _Discuss PyTorch on the Forums: https://discuss.pytorch.org/
+# .. _Chat with other users on Slack: https://pytorch.slack.com/messages/beginner/
+
+# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%
+del dataiter
+# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%

+ 21 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.pytorch_latest

@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM pytorch-latest-gpu-base:0.0.1
+
+RUN mkdir -p /test/data
+RUN chmod -R 777 /test
+ADD cifar10_tutorial.py /test/cifar10_tutorial.py
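
The resulting image can be smoke-tested outside YARN before submitting a job with it; a sketch, assuming the NVIDIA container runtime is configured (the tutorial script falls back to CPU when CUDA is unavailable and downloads CIFAR-10 on first run):

    docker run --runtime=nvidia --rm pytorch-latest-gpu:0.0.1 \
      python /test/cifar10_tutorial.py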

+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/base/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1 → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/base/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/base/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1 → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/base/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/build-all.sh → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/build-all.sh


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1 → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1 → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/README.md → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/README.md


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_main.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_main.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_model.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_model.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_utils.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_utils.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/generate_cifar10_tfrecords.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/generate_cifar10_tfrecords.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/model_base.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/model_base.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/zeppelin-notebook-example/Dockerfile.gpu → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/Dockerfile.gpu


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/zeppelin-notebook-example/run_container.sh → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/run_container.sh


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/zeppelin-notebook-example/shiro.ini → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/shiro.ini


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/zeppelin-notebook-example/zeppelin-site.xml → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/zeppelin-site.xml


+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/Cli.java

@@ -15,6 +15,7 @@
 package org.apache.hadoop.yarn.submarine.client.cli;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;

+ 2 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/CliConstants.java

@@ -59,4 +59,6 @@ public class CliConstants {
   public static final String DISTRIBUTE_KEYTAB = "distribute_keytab";
   public static final String YAML_CONFIG = "f";
   public static final String INSECURE_CLUSTER = "insecure";
+
+  public static final String FRAMEWORK = "framework";
 }

+ 1 - 1
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/CliUtils.java

@@ -16,7 +16,7 @@ package org.apache.hadoop.yarn.submarine.client.cli;
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.exception.SubmarineRuntimeException;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.slf4j.Logger;

+ 24 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/Command.java

@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli;
+
+/**
+ * Represents a Submarine command.
+ */
+public enum Command {
+  RUN_JOB, SHOW_JOB
+}

+ 6 - 6
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/ShowJobCli.java

@@ -37,7 +37,7 @@ public class ShowJobCli extends AbstractCli {
   private static final Logger LOG = LoggerFactory.getLogger(ShowJobCli.class);
 
   private Options options;
-  private ShowJobParameters parameters = new ShowJobParameters();
+  private ParametersHolder parametersHolder;
 
   public ShowJobCli(ClientContext cliContext) {
     super(cliContext);
@@ -62,9 +62,9 @@ public class ShowJobCli extends AbstractCli {
     CommandLine cli;
     try {
       cli = parser.parse(options, args);
-      ParametersHolder parametersHolder = ParametersHolder
-          .createWithCmdLine(cli);
-      parameters.updateParameters(parametersHolder, clientContext);
+      parametersHolder = ParametersHolder
+          .createWithCmdLine(cli, Command.SHOW_JOB);
+      parametersHolder.updateParameters(clientContext);
     } catch (ParseException e) {
       printUsages();
     }
@@ -97,7 +97,7 @@ public class ShowJobCli extends AbstractCli {
 
     Map<String, String> jobInfo = null;
     try {
-      jobInfo = storage.getJobInfoByName(parameters.getName());
+      jobInfo = storage.getJobInfoByName(getParameters().getName());
     } catch (IOException e) {
       LOG.error("Failed to retrieve job info", e);
       throw e;
@@ -108,7 +108,7 @@ public class ShowJobCli extends AbstractCli {
 
   @VisibleForTesting
   public ShowJobParameters getParameters() {
-    return parameters;
+    return (ShowJobParameters) parametersHolder.getParameters();
   }
 
   @Override
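
ShowJobCli now routes parameter handling through ParametersHolder with Command.SHOW_JOB, but the user-facing command is unchanged; roughly (the jar name/version is a placeholder, --name corresponds to CliConstants.NAME):

    yarn jar hadoop-submarine-core-<version>.jar job show --name pytorch-job-001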

+ 24 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/ConfigType.java

@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.param;
+
+/**
+ * Represents the source of configuration.
+ */
+public enum ConfigType {
+  YAML, CLI
+}

+ 134 - 9
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/ParametersHolder.java

@@ -20,8 +20,12 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.ParseException;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.Command;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Roles;
@@ -29,15 +33,22 @@ import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Scheduling;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Security;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.TensorBoard;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
+import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli.YAML_PARSE_FAILED;
+
 /**
  * This class acts as a wrapper of {@code CommandLine} values along with
  * YAML configuration values.
@@ -52,17 +63,110 @@ public final class ParametersHolder {
   private static final Logger LOG =
       LoggerFactory.getLogger(ParametersHolder.class);
 
+  public static final String SUPPORTED_FRAMEWORKS_MESSAGE =
+      "TensorFlow and PyTorch are the only supported frameworks for now!";
+  public static final String SUPPORTED_COMMANDS_MESSAGE =
+      "'Show job' and 'run job' are the only supported commands for now!";
+
   private final CommandLine parsedCommandLine;
   private final Map<String, String> yamlStringConfigs;
   private final Map<String, List<String>> yamlListConfigs;
-  private final ImmutableSet onlyDefinedWithCliArgs = ImmutableSet.of(
+  private final ConfigType configType;
+  private Command command;
+  private final Set onlyDefinedWithCliArgs = ImmutableSet.of(
       CliConstants.VERBOSE);
+  private final Framework framework;
+  private final BaseParameters parameters;
 
   private ParametersHolder(CommandLine parsedCommandLine,
-      YamlConfigFile yamlConfig) {
+      YamlConfigFile yamlConfig, ConfigType configType, Command command)
+      throws ParseException, YarnException {
     this.parsedCommandLine = parsedCommandLine;
     this.yamlStringConfigs = initStringConfigValues(yamlConfig);
     this.yamlListConfigs = initListConfigValues(yamlConfig);
+    this.configType = configType;
+    this.command = command;
+    this.framework = determineFrameworkType();
+    this.ensureOnlyValidSectionsAreDefined(yamlConfig);
+    this.parameters = createParameters();
+  }
+
+  private BaseParameters createParameters() {
+    if (command == Command.RUN_JOB) {
+      if (framework == Framework.TENSORFLOW) {
+        return new TensorFlowRunJobParameters();
+      } else if (framework == Framework.PYTORCH) {
+        return new PyTorchRunJobParameters();
+      } else {
+        throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
+      }
+    } else if (command == Command.SHOW_JOB) {
+      return new ShowJobParameters();
+    } else {
+      throw new UnsupportedOperationException(SUPPORTED_COMMANDS_MESSAGE);
+    }
+  }
+
+  private void ensureOnlyValidSectionsAreDefined(YamlConfigFile yamlConfig) {
+    if (isCommandRunJob() && isFrameworkPyTorch() &&
+        isPsSectionDefined(yamlConfig)) {
+      throw new YamlParseException(
+          "PS section should not be defined when PyTorch " +
+              "is the selected framework!");
+    }
+
+    if (isCommandRunJob() && isFrameworkPyTorch() &&
+        isTensorboardSectionDefined(yamlConfig)) {
+      throw new YamlParseException(
+          "TensorBoard section should not be defined when PyTorch " +
+              "is the selected framework!");
+    }
+  }
+
+  private boolean isCommandRunJob() {
+    return command == Command.RUN_JOB;
+  }
+
+  private boolean isFrameworkPyTorch() {
+    return framework == Framework.PYTORCH;
+  }
+
+  private boolean isPsSectionDefined(YamlConfigFile yamlConfig) {
+    return yamlConfig != null &&
+        yamlConfig.getRoles() != null &&
+        yamlConfig.getRoles().getPs() != null;
+  }
+
+  private boolean isTensorboardSectionDefined(YamlConfigFile yamlConfig) {
+    return yamlConfig != null &&
+        yamlConfig.getTensorBoard() != null;
+  }
+
+  private Framework determineFrameworkType()
+      throws ParseException, YarnException {
+    if (!isCommandRunJob()) {
+      return null;
+    }
+    String frameworkStr = getOptionValue(CliConstants.FRAMEWORK);
+    if (frameworkStr == null) {
+      LOG.info("Framework is not defined in config, falling back to " +
+          "TensorFlow as a default.");
+      return Framework.TENSORFLOW;
+    }
+    Framework framework = Framework.parseByValue(frameworkStr);
+    if (framework == null) {
+      if (getConfigType() == ConfigType.CLI) {
+        throw new ParseException("Failed to parse Framework type! "
            + "Valid values are: " + Framework.getValues());
+      } else {
+        throw new YamlParseException(YAML_PARSE_FAILED +
+            ", framework is defined, but it has an invalid value! " +
+            "Valid values are: " + Framework.getValues());
+      }
+    }
+    return framework;
   }
 
   /**
@@ -108,6 +212,8 @@ public final class ParametersHolder {
   private void initGenericConfigs(YamlConfigFile yamlConfig,
       Map<String, String> yamlConfigs) {
     yamlConfigs.put(CliConstants.NAME, yamlConfig.getSpec().getName());
+    yamlConfigs.put(CliConstants.FRAMEWORK,
+        yamlConfig.getSpec().getFramework());
 
     Configs configs = yamlConfig.getConfigs();
     yamlConfigs.put(CliConstants.INPUT_PATH, configs.getInputPath());
@@ -178,13 +284,15 @@ public final class ParametersHolder {
         .collect(Collectors.toList());
   }
 
-  public static ParametersHolder createWithCmdLine(CommandLine cli) {
-    return new ParametersHolder(cli, null);
+  public static ParametersHolder createWithCmdLine(CommandLine cli,
+      Command command) throws ParseException, YarnException {
+    return new ParametersHolder(cli, null, ConfigType.CLI, command);
   }
 
   public static ParametersHolder createWithCmdLineAndYaml(CommandLine cli,
-      YamlConfigFile yamlConfig) {
-    return new ParametersHolder(cli, yamlConfig);
+      YamlConfigFile yamlConfig, Command command) throws ParseException,
+      YarnException {
+    return new ParametersHolder(cli, yamlConfig, ConfigType.YAML, command);
   }
 
   /**
@@ -193,7 +301,7 @@ public final class ParametersHolder {
    * @param option Name of the config.
    * @return The value of the config
    */
-  String getOptionValue(String option) throws YarnException {
+  public String getOptionValue(String option) throws YarnException {
     ensureConfigIsDefinedOnce(option, true);
     if (onlyDefinedWithCliArgs.contains(option) ||
         parsedCommandLine.hasOption(option)) {
@@ -208,7 +316,7 @@ public final class ParametersHolder {
    * @param option Name of the config.
    * @return The values of the config
    */
-  List<String> getOptionValues(String option) throws YarnException {
+  public List<String> getOptionValues(String option) throws YarnException {
     ensureConfigIsDefinedOnce(option, false);
     if (onlyDefinedWithCliArgs.contains(option) ||
         parsedCommandLine.hasOption(option)) {
@@ -285,7 +393,7 @@ public final class ParametersHolder {
    * @return true, if the option is found in the CLI args or in the YAML config,
    * false otherwise.
    */
-  boolean hasOption(String option) {
+  public boolean hasOption(String option) {
     if (onlyDefinedWithCliArgs.contains(option)) {
       boolean value = parsedCommandLine.hasOption(option);
       if (LOG.isDebugEnabled()) {
@@ -312,4 +420,21 @@ public final class ParametersHolder {
         "from YAML configuration.", result, option);
     return result;
   }
+
+  public ConfigType getConfigType() {
+    return configType;
+  }
+
+  public Framework getFramework() {
+    return framework;
+  }
+
+  public void updateParameters(ClientContext clientContext)
+      throws ParseException, YarnException, IOException {
+    parameters.updateParameters(this, clientContext);
+  }
+
+  public BaseParameters getParameters() {
+    return parameters;
+  }
 }
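
The framework choice is also honored from YAML: initGenericConfigs above maps spec.framework onto CliConstants.FRAMEWORK. A minimal sketch of the spec section, truncated to the fields this diff touches (a complete file also needs the configs and roles sections, as in the runjob-pytorch-yaml test resources added by this commit):

    spec:
      name: pytorch-job-001
      framework: pytorch

Such a file is submitted with the existing -f option (CliConstants.YAML_CONFIG); the jar name/version below is a placeholder:

    yarn jar hadoop-submarine-core-<version>.jar job run -f pytorch-job.yaml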

+ 120 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/PyTorchRunJobParameters.java

@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Parameters for PyTorch job.
+ */
+public class PyTorchRunJobParameters extends RunJobParameters {
+
+  private static final String CANNOT_BE_DEFINED_FOR_PYTORCH =
+      "cannot be defined for PyTorch jobs!";
+
+  @Override
+  public void updateParameters(ParametersHolder parametersHolder,
+      ClientContext clientContext)
+      throws ParseException, IOException, YarnException {
+    checkArguments(parametersHolder);
+
+    super.updateParameters(parametersHolder, clientContext);
+
+    String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
+    this.workerParameters =
+        getWorkerParameters(clientContext, parametersHolder, input);
+    this.distributed = determineIfDistributed(workerParameters.getReplicas());
+    executePostOperations(clientContext);
+  }
+
+  private void checkArguments(ParametersHolder parametersHolder)
+      throws YarnException, ParseException {
+    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.N_PS));
+    } else if (parametersHolder.getOptionValue(CliConstants.PS_RES) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.PS_RES));
+    } else if (parametersHolder
+        .getOptionValue(CliConstants.PS_DOCKER_IMAGE) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.PS_DOCKER_IMAGE));
+    } else if (parametersHolder
+        .getOptionValue(CliConstants.PS_LAUNCH_CMD) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.PS_LAUNCH_CMD));
+    } else if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.TENSORBOARD));
+    } else if (parametersHolder
+        .getOptionValue(CliConstants.TENSORBOARD_RESOURCES) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.TENSORBOARD_RESOURCES));
+    } else if (parametersHolder
+        .getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.TENSORBOARD_DOCKER_IMAGE));
+    }
+  }
+
+  private String getParamCannotBeDefinedErrorMessage(String cliName) {
+    return String.format(
+        "Parameter '%s' " + CANNOT_BE_DEFINED_FOR_PYTORCH, cliName);
+  }
+
+  @Override
+  void executePostOperations(ClientContext clientContext) throws IOException {
+    // Set default job dir / saved model dir, etc.
+    setDefaultDirs(clientContext);
+    replacePatternsInParameters(clientContext);
+  }
+
+  private void replacePatternsInParameters(ClientContext clientContext)
+      throws IOException {
+    if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
+      String afterReplace =
+          CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
+              clientContext.getRemoteDirectoryManager());
+      setWorkerLaunchCmd(afterReplace);
+    }
+  }
+
+  @Override
+  public List<String> getLaunchCommands() {
+    return Lists.newArrayList(getWorkerLaunchCmd());
+  }
+
+  /**
+   * We only support non-distributed PyTorch integration for now.
+   * @param nWorkers number of worker instances requested
+   * @return always false, since distributed PyTorch is not yet supported
+   */
+  private boolean determineIfDistributed(int nWorkers) {
+    return false;
+  }
+}
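
checkArguments above fails fast on every TensorFlow-only option. A self-contained sketch of the same pattern with plain commons-cli (the option names are stand-ins mirroring CliConstants, not the real constants):

    import org.apache.commons.cli.CommandLine;
    import org.apache.commons.cli.GnuParser;
    import org.apache.commons.cli.Options;
    import org.apache.commons.cli.ParseException;

    public class PyTorchArgCheckDemo {
      // TF-only options, mirrored from the checks above (names assumed)
      private static final String[] TF_ONLY = {"num_ps", "ps_resources",
          "ps_docker_image", "ps_launch_cmd", "tensorboard"};

      public static void main(String[] args) throws ParseException {
        Options options = new Options();
        options.addOption("framework", true, "tensorflow or pytorch");
        for (String opt : TF_ONLY) {
          // "tensorboard" is a flag; the others take a value
          options.addOption(opt, !opt.equals("tensorboard"), "TF-only option");
        }
        CommandLine cli = new GnuParser().parse(options,
            new String[] {"-framework", "pytorch", "-num_ps", "2"});
        if ("pytorch".equalsIgnoreCase(cli.getOptionValue("framework"))) {
          for (String opt : TF_ONLY) {
            if (cli.hasOption(opt)) {
              // Same message shape as getParamCannotBeDefinedErrorMessage
              throw new ParseException(String.format(
                  "Parameter '%s' cannot be defined for PyTorch jobs!", opt));
            }
          }
        }
      }
    }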

+ 146 - 197
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java → hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/RunJobParameters.java

@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 /**
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -12,7 +28,7 @@
  * limitations under the License. See accompanying LICENSE file.
  */
 
-package org.apache.hadoop.yarn.submarine.client.cli.param;
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.CaseFormat;
@@ -21,7 +37,14 @@ import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
 import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.util.resource.ResourceUtils;
 import org.yaml.snakeyaml.introspector.Property;
 import org.yaml.snakeyaml.introspector.PropertyUtils;
@@ -34,27 +57,15 @@ import java.util.List;
 /**
  * Parameters used to run a job
  */
-public class RunJobParameters extends RunParameters {
+public abstract class RunJobParameters extends RunParameters {
   private String input;
   private String checkpointPath;
 
-  private int numWorkers;
-  private int numPS;
-  private Resource workerResource;
-  private Resource psResource;
-  private boolean tensorboardEnabled;
-  private Resource tensorboardResource;
-  private String tensorboardDockerImage;
-  private String workerLaunchCmd;
-  private String psLaunchCmd;
   private List<Quicklink> quicklinks = new ArrayList<>();
   private List<Localization> localizations = new ArrayList<>();
 
-  private String psDockerImage = null;
-  private String workerDockerImage = null;
-
   private boolean waitJobFinish = false;
-  private boolean distributed = false;
+  protected boolean distributed = false;
 
   private boolean securityDisabled = false;
   private String keytab;
@@ -62,6 +73,9 @@ public class RunJobParameters extends RunParameters {
   private boolean distributeKeytab = false;
   private List<String> confPairs = new ArrayList<>();
 
+  RoleParameters workerParameters =
+      RoleParameters.createEmpty(TensorFlowRole.WORKER);
+
   @Override
   public void updateParameters(ParametersHolder parametersHolder,
       ClientContext clientContext)
@@ -70,34 +84,6 @@ public class RunJobParameters extends RunParameters {
     String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
     String jobDir = parametersHolder.getOptionValue(
         CliConstants.CHECKPOINT_PATH);
-    int nWorkers = 1;
-    if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
-      nWorkers = Integer.parseInt(
-          parametersHolder.getOptionValue(CliConstants.N_WORKERS));
-      // Only check null value.
-      // Training job shouldn't ignore INPUT_PATH option
-      // But if nWorkers is 0, INPUT_PATH can be ignored because
-      // user can only run Tensorboard
-      if (null == input && 0 != nWorkers) {
-        throw new ParseException("\"--" + CliConstants.INPUT_PATH +
-            "\" is absent");
-      }
-    }
-
-    int nPS = 0;
-    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
-      nPS = Integer.parseInt(
-          parametersHolder.getOptionValue(CliConstants.N_PS));
-    }
-
-    // Check #workers and #ps.
-    // When distributed training is required
-    if (nWorkers >= 2 && nPS > 0) {
-      distributed = true;
-    } else if (nWorkers <= 1 && nPS > 0) {
-      throw new ParseException("Only specified one worker but non-zero PS, "
-          + "please double check.");
-    }
 
     if (parametersHolder.hasOption(CliConstants.INSECURE_CLUSTER)) {
       setSecurityDisabled(true);
@@ -109,46 +95,6 @@ public class RunJobParameters extends RunParameters {
         CliConstants.PRINCIPAL);
     CliUtils.doLoginIfSecure(kerberosKeytab, kerberosPrincipal);
 
-    workerResource = null;
-    if (nWorkers > 0) {
-      String workerResourceStr = parametersHolder.getOptionValue(
-          CliConstants.WORKER_RES);
-      if (workerResourceStr == null) {
-        throw new ParseException(
-            "--" + CliConstants.WORKER_RES + " is absent.");
-      }
-      workerResource = ResourceUtils.createResourceFromString(
-          workerResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-    }
-
-    Resource psResource = null;
-    if (nPS > 0) {
-      String psResourceStr = parametersHolder.getOptionValue(
-          CliConstants.PS_RES);
-      if (psResourceStr == null) {
-        throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
-      }
-      psResource = ResourceUtils.createResourceFromString(psResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-    }
-
-    boolean tensorboard = false;
-    if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
-      tensorboard = true;
-      String tensorboardResourceStr = parametersHolder.getOptionValue(
-          CliConstants.TENSORBOARD_RESOURCES);
-      if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
-        tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
-      }
-      tensorboardResource = ResourceUtils.createResourceFromString(
-          tensorboardResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-      tensorboardDockerImage = parametersHolder.getOptionValue(
-          CliConstants.TENSORBOARD_DOCKER_IMAGE);
-      this.setTensorboardResource(tensorboardResource);
-    }
-
     if (parametersHolder.hasOption(CliConstants.WAIT_JOB_FINISH)) {
       this.waitJobFinish = true;
     }
@@ -164,16 +110,6 @@ public class RunJobParameters extends RunParameters {
       }
     }
 
-    psDockerImage = parametersHolder.getOptionValue(
-        CliConstants.PS_DOCKER_IMAGE);
-    workerDockerImage = parametersHolder.getOptionValue(
-        CliConstants.WORKER_DOCKER_IMAGE);
-
-    String workerLaunchCmd = parametersHolder.getOptionValue(
-        CliConstants.WORKER_LAUNCH_CMD);
-    String psLaunchCommand = parametersHolder.getOptionValue(
-        CliConstants.PS_LAUNCH_CMD);
-
     // Localizations
     List<String> localizationsStr = parametersHolder.getOptionValues(
         CliConstants.LOCALIZATION);
@@ -191,10 +127,6 @@ public class RunJobParameters extends RunParameters {
         .getOptionValues(CliConstants.ARG_CONF);
 
     this.setInputPath(input).setCheckpointPath(jobDir)
-        .setNumPS(nPS).setNumWorkers(nWorkers)
-        .setPSLaunchCmd(psLaunchCommand).setWorkerLaunchCmd(workerLaunchCmd)
-        .setPsResource(psResource)
-        .setTensorboardEnabled(tensorboard)
         .setKeytab(kerberosKeytab)
         .setPrincipal(kerberosPrincipal)
         .setDistributeKeytab(distributeKerberosKeytab)
@@ -203,6 +135,39 @@ public class RunJobParameters extends RunParameters {
     super.updateParameters(parametersHolder, clientContext);
   }
 
+  abstract void executePostOperations(ClientContext clientContext)
+      throws IOException;
+
+  void setDefaultDirs(ClientContext clientContext) throws IOException {
+    // Create directories if needed
+    String jobDir = getCheckpointPath();
+    if (jobDir == null) {
+      jobDir = getJobDir(clientContext);
+      setCheckpointPath(jobDir);
+    }
+
+    if (getNumWorkers() > 0) {
+      String savedModelDir = getSavedModelPath();
+      if (savedModelDir == null) {
+        savedModelDir = jobDir;
+        setSavedModelPath(savedModelDir);
+      }
+    }
+  }
+
+  private String getJobDir(ClientContext clientContext) throws IOException {
+    RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
+    if (getNumWorkers() > 0) {
+      return rdm.getJobCheckpointDir(getName(), true).toString();
+    } else {
+      // when #workers == 0, it means we only launch TB. In that case,
+      // point job dir to root dir so all job's metrics will be shown.
+      return rdm.getUserRootFolder().toString();
+    }
+  }
+
+  public abstract List<String> getLaunchCommands();
+
   public String getInputPath() {
     return input;
   }
@@ -221,110 +186,10 @@ public class RunJobParameters extends RunParameters {
     return this;
   }
 
-  public int getNumWorkers() {
-    return numWorkers;
-  }
-
-  public RunJobParameters setNumWorkers(int numWorkers) {
-    this.numWorkers = numWorkers;
-    return this;
-  }
-
-  public int getNumPS() {
-    return numPS;
-  }
-
-  public RunJobParameters setNumPS(int numPS) {
-    this.numPS = numPS;
-    return this;
-  }
-
-  public Resource getWorkerResource() {
-    return workerResource;
-  }
-
-  public RunJobParameters setWorkerResource(Resource workerResource) {
-    this.workerResource = workerResource;
-    return this;
-  }
-
-  public Resource getPsResource() {
-    return psResource;
-  }
-
-  public RunJobParameters setPsResource(Resource psResource) {
-    this.psResource = psResource;
-    return this;
-  }
-
-  public boolean isTensorboardEnabled() {
-    return tensorboardEnabled;
-  }
-
-  public RunJobParameters setTensorboardEnabled(boolean tensorboardEnabled) {
-    this.tensorboardEnabled = tensorboardEnabled;
-    return this;
-  }
-
-  public String getWorkerLaunchCmd() {
-    return workerLaunchCmd;
-  }
-
-  public RunJobParameters setWorkerLaunchCmd(String workerLaunchCmd) {
-    this.workerLaunchCmd = workerLaunchCmd;
-    return this;
-  }
-
-  public String getPSLaunchCmd() {
-    return psLaunchCmd;
-  }
-
-  public RunJobParameters setPSLaunchCmd(String psLaunchCmd) {
-    this.psLaunchCmd = psLaunchCmd;
-    return this;
-  }
-
   public boolean isWaitJobFinish() {
     return waitJobFinish;
   }
 
-
-  public String getPsDockerImage() {
-    return psDockerImage;
-  }
-
-  public void setPsDockerImage(String psDockerImage) {
-    this.psDockerImage = psDockerImage;
-  }
-
-  public String getWorkerDockerImage() {
-    return workerDockerImage;
-  }
-
-  public void setWorkerDockerImage(String workerDockerImage) {
-    this.workerDockerImage = workerDockerImage;
-  }
-
-  public boolean isDistributed() {
-    return distributed;
-  }
-
-  public Resource getTensorboardResource() {
-    return tensorboardResource;
-  }
-
-  public void setTensorboardResource(Resource tensorboardResource) {
-    this.tensorboardResource = tensorboardResource;
-  }
-
-  public String getTensorboardDockerImage() {
-    return tensorboardDockerImage;
-  }
-
-  public void setTensorboardDockerImage(String tensorboardDockerImage) {
-    this.tensorboardDockerImage = tensorboardDockerImage;
-  }
-
   public List<Quicklink> getQuicklinks() {
     return quicklinks;
   }
@@ -382,6 +247,90 @@ public class RunJobParameters extends RunParameters {
     this.distributed = distributed;
   }
 
+  RoleParameters getWorkerParameters(ClientContext clientContext,
+      ParametersHolder parametersHolder, String input)
+      throws ParseException, YarnException, IOException {
+    int nWorkers = getNumberOfWorkers(parametersHolder, input);
+    Resource workerResource =
+        determineWorkerResource(parametersHolder, nWorkers, clientContext);
+    String workerDockerImage =
+        parametersHolder.getOptionValue(CliConstants.WORKER_DOCKER_IMAGE);
+    String workerLaunchCmd =
+        parametersHolder.getOptionValue(CliConstants.WORKER_LAUNCH_CMD);
+    return new RoleParameters(TensorFlowRole.WORKER, nWorkers,
+        workerLaunchCmd, workerDockerImage, workerResource);
+  }
+
+  private Resource determineWorkerResource(ParametersHolder parametersHolder,
+      int nWorkers, ClientContext clientContext)
+      throws ParseException, YarnException, IOException {
+    if (nWorkers > 0) {
+      String workerResourceStr =
+          parametersHolder.getOptionValue(CliConstants.WORKER_RES);
+      if (workerResourceStr == null) {
+        throw new ParseException(
+            "--" + CliConstants.WORKER_RES + " is absent.");
+      }
+      return ResourceUtils.createResourceFromString(workerResourceStr,
+          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
+    }
+    return null;
+  }
+
+  private int getNumberOfWorkers(ParametersHolder parametersHolder,
+      String input) throws ParseException, YarnException {
+    int nWorkers = 1;
+    if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
+      nWorkers = Integer
+          .parseInt(parametersHolder.getOptionValue(CliConstants.N_WORKERS));
+      // Only check null value.
+      // Training job shouldn't ignore INPUT_PATH option
+      // But if nWorkers is 0, INPUT_PATH can be ignored because
+      // user can only run Tensorboard
+      if (null == input && 0 != nWorkers) {
+        throw new ParseException(
+            "\"--" + CliConstants.INPUT_PATH + "\" is absent");
+      }
+    }
+    return nWorkers;
+  }
+
+  public String getWorkerLaunchCmd() {
+    return workerParameters.getLaunchCommand();
+  }
+
+  public void setWorkerLaunchCmd(String launchCmd) {
+    workerParameters.setLaunchCommand(launchCmd);
+  }
+
+  public int getNumWorkers() {
+    return workerParameters.getReplicas();
+  }
+
+  public void setNumWorkers(int numWorkers) {
+    workerParameters.setReplicas(numWorkers);
+  }
+
+  public Resource getWorkerResource() {
+    return workerParameters.getResource();
+  }
+
+  public void setWorkerResource(Resource resource) {
+    workerParameters.setResource(resource);
+  }
+
+  public String getWorkerDockerImage() {
+    return workerParameters.getDockerImage();
+  }
+
+  public void setWorkerDockerImage(String image) {
+    workerParameters.setDockerImage(image);
+  }
+
+  public boolean isDistributed() {
+    return distributed;
+  }
+
   @VisibleForTesting
   public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
     @Override
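
The net effect of this refactor is a template-method split: RunJobParameters keeps the framework-neutral fields plus the worker role, while subclasses own validation, post-processing and launch-command assembly. A sketch of the polymorphic consumption (types from this patch assumed on the classpath):

    import java.util.List;
    import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;

    final class LaunchCommandDump {
      // TensorFlow parameters yield worker + PS commands,
      // PyTorch parameters yield only the worker command.
      static void dump(RunJobParameters params) {
        List<String> cmds = params.getLaunchCommands();
        cmds.forEach(System.out::println);
      }

      private LaunchCommandDump() {
      }
    }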

+ 215 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/TensorFlowRunJobParameters.java

@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Parameters for TensorFlow job.
+ */
+public class TensorFlowRunJobParameters extends RunJobParameters {
+  private boolean tensorboardEnabled;
+  private RoleParameters psParameters =
+      RoleParameters.createEmpty(TensorFlowRole.PS);
+  private RoleParameters tensorBoardParameters =
+      RoleParameters.createEmpty(TensorFlowRole.TENSORBOARD);
+
+  @Override
+  public void updateParameters(ParametersHolder parametersHolder,
+      ClientContext clientContext)
+      throws ParseException, IOException, YarnException {
+    super.updateParameters(parametersHolder, clientContext);
+
+    String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
+    this.workerParameters =
+        getWorkerParameters(clientContext, parametersHolder, input);
+    this.psParameters = getPSParameters(clientContext, parametersHolder);
+    this.distributed = determineIfDistributed(workerParameters.getReplicas(),
+        psParameters.getReplicas());
+
+    if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
+      this.tensorboardEnabled = true;
+      this.tensorBoardParameters =
+          getTensorBoardParameters(parametersHolder, clientContext);
+    }
+    executePostOperations(clientContext);
+  }
+
+  @Override
+  void executePostOperations(ClientContext clientContext) throws IOException {
+    // Set default job dir / saved model dir, etc.
+    setDefaultDirs(clientContext);
+    replacePatternsInParameters(clientContext);
+  }
+
+  private void replacePatternsInParameters(ClientContext clientContext)
+      throws IOException {
+    if (StringUtils.isNotEmpty(getPSLaunchCmd())) {
+      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
+          getPSLaunchCmd(), this, clientContext.getRemoteDirectoryManager());
+      setPSLaunchCmd(afterReplace);
+    }
+
+    if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
+      String afterReplace =
+          CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
+              clientContext.getRemoteDirectoryManager());
+      setWorkerLaunchCmd(afterReplace);
+    }
+  }
+
+  @Override
+  public List<String> getLaunchCommands() {
+    return Lists.newArrayList(getWorkerLaunchCmd(), getPSLaunchCmd());
+  }
+
+  private boolean determineIfDistributed(int nWorkers, int nPS)
+      throws ParseException {
+    // Check #workers and #ps.
+    // When distributed training is required
+    if (nWorkers >= 2 && nPS > 0) {
+      return true;
+    } else if (nWorkers <= 1 && nPS > 0) {
+      throw new ParseException("Only specified one worker but non-zero PS, "
+          + "please double check.");
+    }
+    return false;
+  }
+
+  private RoleParameters getPSParameters(ClientContext clientContext,
+      ParametersHolder parametersHolder)
+      throws YarnException, IOException, ParseException {
+    int nPS = getNumberOfPS(parametersHolder);
+    Resource psResource =
+        determinePSResource(parametersHolder, nPS, clientContext);
+    String psDockerImage =
+        parametersHolder.getOptionValue(CliConstants.PS_DOCKER_IMAGE);
+    String psLaunchCommand =
+        parametersHolder.getOptionValue(CliConstants.PS_LAUNCH_CMD);
+    return new RoleParameters(TensorFlowRole.PS, nPS, psLaunchCommand,
+        psDockerImage, psResource);
+  }
+
+  private Resource determinePSResource(ParametersHolder parametersHolder,
+      int nPS, ClientContext clientContext)
+      throws ParseException, YarnException, IOException {
+    if (nPS > 0) {
+      String psResourceStr =
+          parametersHolder.getOptionValue(CliConstants.PS_RES);
+      if (psResourceStr == null) {
+        throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
+      }
+      return ResourceUtils.createResourceFromString(psResourceStr,
+          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
+    }
+    return null;
+  }
+
+  private int getNumberOfPS(ParametersHolder parametersHolder)
+      throws YarnException {
+    int nPS = 0;
+    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
+      nPS =
+          Integer.parseInt(parametersHolder.getOptionValue(CliConstants.N_PS));
+    }
+    return nPS;
+  }
+
+  private RoleParameters getTensorBoardParameters(
+      ParametersHolder parametersHolder, ClientContext clientContext)
+      throws YarnException, IOException {
+    String tensorboardResourceStr =
+        parametersHolder.getOptionValue(CliConstants.TENSORBOARD_RESOURCES);
+    if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
+      tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
+    }
+    Resource tensorboardResource =
+        ResourceUtils.createResourceFromString(tensorboardResourceStr,
+            clientContext.getOrCreateYarnClient().getResourceTypeInfo());
+    String tensorboardDockerImage =
+        parametersHolder.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE);
+    return new RoleParameters(TensorFlowRole.TENSORBOARD, 1, null,
+        tensorboardDockerImage, tensorboardResource);
+  }
+
+  public int getNumPS() {
+    return psParameters.getReplicas();
+  }
+
+  public void setNumPS(int numPS) {
+    psParameters.setReplicas(numPS);
+  }
+
+  public Resource getPsResource() {
+    return psParameters.getResource();
+  }
+
+  public void setPsResource(Resource resource) {
+    psParameters.setResource(resource);
+  }
+
+  public String getPsDockerImage() {
+    return psParameters.getDockerImage();
+  }
+
+  public void setPsDockerImage(String image) {
+    psParameters.setDockerImage(image);
+  }
+
+  public String getPSLaunchCmd() {
+    return psParameters.getLaunchCommand();
+  }
+
+  public void setPSLaunchCmd(String launchCmd) {
+    psParameters.setLaunchCommand(launchCmd);
+  }
+
+  public boolean isTensorboardEnabled() {
+    return tensorboardEnabled;
+  }
+
+  public Resource getTensorboardResource() {
+    return tensorBoardParameters.getResource();
+  }
+
+  public void setTensorboardResource(Resource resource) {
+    tensorBoardParameters.setResource(resource);
+  }
+
+  public String getTensorboardDockerImage() {
+    return tensorBoardParameters.getDockerImage();
+  }
+
+  public void setTensorboardDockerImage(String image) {
+    tensorBoardParameters.setDockerImage(image);
+  }
+
+}
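
determineIfDistributed above encodes a small decision table: two or more workers with at least one PS means distributed, a single worker with PS instances is rejected, and anything else is non-distributed. Restated standalone for quick reference (plain Java, no Submarine types; ParseException swapped for IllegalArgumentException):

    // nWorkers=2, nPS=1 -> true; nWorkers=1, nPS=0 -> false;
    // nWorkers=1, nPS=1 -> throws
    static boolean isDistributed(int nWorkers, int nPS) {
      if (nWorkers >= 2 && nPS > 0) {
        return true;
      }
      if (nWorkers <= 1 && nPS > 0) {
        throw new IllegalArgumentException(
            "Only specified one worker but non-zero PS, please double check.");
      }
      return false;
    }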

+ 20 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/package-info.java

@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes that hold run job parameters for
+ * TensorFlow / PyTorch jobs.
+ */
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;

+ 9 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/yaml/Spec.java

@@ -22,6 +22,7 @@ package org.apache.hadoop.yarn.submarine.client.cli.param.yaml;
 public class Spec {
   private String name;
   private String jobType;
+  private String framework;
 
   public String getJobType() {
     return jobType;
@@ -38,4 +39,12 @@ public class Spec {
   public void setName(String name) {
     this.name = name;
   }
+
+  public String getFramework() {
+    return framework;
+  }
+
+  public void setFramework(String framework) {
+    this.framework = framework;
+  }
 }
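
With the new field, a YAML job definition can carry the framework choice in its spec section. A minimal, illustrative fragment (only name and framework shown; key spelling assumes the snake_case convention used by the YAML configs in this patch set, and valid framework values come from the Framework enum introduced below):

    spec:
      name: pytorch-job-001
      framework: pytorch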

+ 59 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/Framework.java

@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Represents the type of Machine learning framework to work with.
+ */
+public enum Framework {
+  TENSORFLOW(Constants.TENSORFLOW_NAME), PYTORCH(Constants.PYTORCH_NAME);
+
+  private String value;
+
+  Framework(String value) {
+    this.value = value;
+  }
+
+  public String getValue() {
+    return value;
+  }
+
+  public static Framework parseByValue(String value) {
+    for (Framework fw : Framework.values()) {
+      if (fw.value.equalsIgnoreCase(value)) {
+        return fw;
+      }
+    }
+    return null;
+  }
+
+  public static String getValues() {
+    List<String> values = Lists.newArrayList(Framework.values()).stream()
+        .map(fw -> fw.value).collect(Collectors.toList());
+    return String.join(",", values);
+  }
+
+  private static class Constants {
+    static final String TENSORFLOW_NAME = "tensorflow";
+    static final String PYTORCH_NAME = "pytorch";
+  }
+}
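
A usage sketch for the enum's lookup helpers (hypothetical caller code; parseByValue matches case-insensitively and returns null for unknown names, so callers must handle that themselves):

    Framework fw = Framework.parseByValue("PyTorch"); // -> PYTORCH
    if (Framework.parseByValue("mxnet") == null) {
      throw new IllegalArgumentException(
          "Invalid framework! Valid values are: " + Framework.getValues());
    }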

+ 81 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/RoleParameters.java

@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+
+/**
+ * This class encapsulates data related to a particular Role.
+ * Some examples: TF Worker process, TF PS process or a PyTorch worker process.
+ */
+public class RoleParameters {
+  private final Role role;
+  private int replicas;
+  private String launchCommand;
+  private String dockerImage;
+  private Resource resource;
+
+  public RoleParameters(Role role, int replicas,
+      String launchCommand, String dockerImage, Resource resource) {
+    this.role = role;
+    this.replicas = replicas;
+    this.launchCommand = launchCommand;
+    this.dockerImage = dockerImage;
+    this.resource = resource;
+  }
+
+  public static RoleParameters createEmpty(Role role) {
+    return new RoleParameters(role, 0, null, null, null);
+  }
+
+  public Role getRole() {
+    return role;
+  }
+
+  public int getReplicas() {
+    return replicas;
+  }
+
+  public String getLaunchCommand() {
+    return launchCommand;
+  }
+
+  public void setLaunchCommand(String launchCommand) {
+    this.launchCommand = launchCommand;
+  }
+
+  public String getDockerImage() {
+    return dockerImage;
+  }
+
+  public void setDockerImage(String dockerImage) {
+    this.dockerImage = dockerImage;
+  }
+
+  public Resource getResource() {
+    return resource;
+  }
+
+  public void setResource(Resource resource) {
+    this.resource = resource;
+  }
+
+  public void setReplicas(int replicas) {
+    this.replicas = replicas;
+  }
+}
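
RoleParameters is a plain mutable holder; the patch builds it either fully through the constructor (see getWorkerParameters / getPSParameters) or via createEmpty plus setters. A sketch of the second style, assuming the usual YARN and Submarine imports; the image name and launch command are hypothetical:

    RoleParameters worker = RoleParameters.createEmpty(TensorFlowRole.WORKER);
    worker.setReplicas(2);
    worker.setResource(Resource.newInstance(4096, 2)); // 4 GB, 2 vcores
    worker.setDockerImage("example/tf-image:latest");  // hypothetical image
    worker.setLaunchCommand("python train.py");        // hypothetical command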

+ 123 - 106
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/RunJobCli.java → hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/RunJobCli.java

@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 /**
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -12,7 +28,7 @@
  * limitations under the License. See accompanying LICENSE file.
  */
 
-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.commons.cli.CommandLine;
@@ -23,9 +39,13 @@ import org.apache.commons.cli.ParseException;
 import org.apache.commons.io.FileUtils;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.AbstractCli;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.Command;
 import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
@@ -44,17 +64,25 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
+/**
+ * The purpose of this class is to handle / parse CLI arguments related to
+ * the run job Submarine command.
+ */
 public class RunJobCli extends AbstractCli {
   private static final Logger LOG =
       LoggerFactory.getLogger(RunJobCli.class);
-  private static final String YAML_PARSE_FAILED = "Failed to parse " +
+  private static final String CAN_BE_USED_WITH_TF_PYTORCH =
+      "Can be used with TensorFlow or PyTorch frameworks.";
+  private static final String CAN_BE_USED_WITH_TF_ONLY =
+      "Can only be used with TensorFlow framework.";
+  public static final String YAML_PARSE_FAILED = "Failed to parse " +
       "YAML config";
       "YAML config";
 
 
-  private Options options;
-  private RunJobParameters parameters = new RunJobParameters();
 
 
+  private Options options;
   private JobSubmitter jobSubmitter;
   private JobSubmitter jobSubmitter;
   private JobMonitor jobMonitor;
   private JobMonitor jobMonitor;
+  private ParametersHolder parametersHolder;
 
 
   public RunJobCli(ClientContext cliContext) {
   public RunJobCli(ClientContext cliContext) {
     this(cliContext, cliContext.getRuntimeFactory().getJobSubmitterInstance(),
     this(cliContext, cliContext.getRuntimeFactory().getJobSubmitterInstance(),
@@ -62,7 +90,7 @@ public class RunJobCli extends AbstractCli {
   }
 
   @VisibleForTesting
-  RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
+  public RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
       JobMonitor jobMonitor) {
     super(cliContext);
     this.options = generateOptions();
@@ -78,6 +106,10 @@ public class RunJobCli extends AbstractCli {
     Options options = new Options();
     options.addOption(CliConstants.YAML_CONFIG, true,
         "Config file (in YAML format)");
+    options.addOption(CliConstants.FRAMEWORK, true,
+        String.format("Framework to use. Valid values are: %s! " +
+                "The default framework is Tensorflow.",
+            Framework.getValues()));
     options.addOption(CliConstants.NAME, true, "Name of the job");
     options.addOption(CliConstants.INPUT_PATH, true,
         "Input of the job, could be local or other FS directory");
@@ -88,48 +120,22 @@ public class RunJobCli extends AbstractCli {
     options.addOption(CliConstants.SAVED_MODEL_PATH, true,
         "Model exported path (savedmodel) of the job, which is needed when "
             + "exported model is not placed under ${checkpoint_path}"
-            + "could be local or other FS directory. This will be used to serve.");
-    options.addOption(CliConstants.N_WORKERS, true,
-        "Number of worker tasks of the job, by default it's 1");
-    options.addOption(CliConstants.N_PS, true,
-        "Number of PS tasks of the job, by default it's 0");
-    options.addOption(CliConstants.WORKER_RES, true,
-        "Resource of each worker, for example "
-            + "memory-mb=2048,vcores=2,yarn.io/gpu=2");
-    options.addOption(CliConstants.PS_RES, true,
-        "Resource of each PS, for example "
-            + "memory-mb=2048,vcores=2,yarn.io/gpu=2");
+            + "could be local or other FS directory. " +
+            "This will be used to serve.");
     options.addOption(CliConstants.DOCKER_IMAGE, true, "Docker image name/tag");
     options.addOption(CliConstants.QUEUE, true,
         "Name of queue to run the job, by default it uses default queue");
-    options.addOption(CliConstants.TENSORBOARD, false,
-        "Should we run TensorBoard"
-            + " for this job? By default it's disabled");
-    options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
-        "Specify resources of Tensorboard, by default it is "
-            + CliConstants.TENSORBOARD_DEFAULT_RESOURCES);
-    options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
-        "Specify Tensorboard docker image. when this is not "
-            + "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
-            + " as default.");
-    options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
-        "Commandline of worker, arguments will be "
-            + "directly used to launch the worker");
-    options.addOption(CliConstants.PS_LAUNCH_CMD, true,
-        "Commandline of worker, arguments will be "
-            + "directly used to launch the PS");
+
+    addWorkerOptions(options);
+    addPSOptions(options);
+    addTensorboardOptions(options);
+
     options.addOption(CliConstants.ENV, true,
     options.addOption(CliConstants.ENV, true,
         "Common environment variable of worker/ps");
     options.addOption(CliConstants.VERBOSE, false,
         "Print verbose log for troubleshooting");
     options.addOption(CliConstants.WAIT_JOB_FINISH, false,
         "Specified when user want to wait the job finish");
-        "Specify docker image for PS, when this is not specified, PS uses --"
-            + CliConstants.DOCKER_IMAGE + " as default.");
-    options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
-        "Specify docker image for WORKER, when this is not specified, WORKER "
-            + "uses --" + CliConstants.DOCKER_IMAGE + " as default.");
     options.addOption(CliConstants.QUICKLINK, true, "Specify quicklink so YARN"
         + "web UI shows link to given role instance and port. When "
         + "--tensorboard is specified, quicklink to tensorboard instance will "
@@ -172,63 +178,97 @@ public class RunJobCli extends AbstractCli {
     return options;
   }
 
-  private void replacePatternsInParameters() throws IOException {
-    if (parameters.getPSLaunchCmd() != null && !parameters.getPSLaunchCmd()
-        .isEmpty()) {
-      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
-          parameters.getPSLaunchCmd(), parameters,
-          clientContext.getRemoteDirectoryManager());
-      parameters.setPSLaunchCmd(afterReplace);
-    }
+  private void addWorkerOptions(Options options) {
+    options.addOption(CliConstants.N_WORKERS, true,
+        "Number of worker tasks of the job, by default it's 1. " +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
+        "Specify docker image for WORKER, when this is not specified, WORKER "
+            + "uses --" + CliConstants.DOCKER_IMAGE + " as default. " +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
+        "Commandline of worker, arguments will be "
+            + "directly used to launch the worker. " +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_RES, true,
+        "Resource of each worker, for example "
+            + "memory-mb=2048,vcores=2,yarn.io/gpu=2. " +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+  }
 
-    if (parameters.getWorkerLaunchCmd() != null && !parameters
-        .getWorkerLaunchCmd().isEmpty()) {
-      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
-          parameters.getWorkerLaunchCmd(), parameters,
-          clientContext.getRemoteDirectoryManager());
-      parameters.setWorkerLaunchCmd(afterReplace);
-    }
+  private void addPSOptions(Options options) {
+    options.addOption(CliConstants.N_PS, true,
+        "Number of PS tasks of the job, by default it's 0. " +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
+        "Specify docker image for PS, when this is not specified, PS uses --"
+            + CliConstants.DOCKER_IMAGE + " as default. " +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_LAUNCH_CMD, true,
+        "Commandline of PS, arguments will be "
+            + "directly used to launch the PS. " +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_RES, true,
+        "Resource of each PS, for example "
+            + "memory-mb=2048,vcores=2,yarn.io/gpu=2. " +
+            CAN_BE_USED_WITH_TF_ONLY);
+  }
+
+  private void addTensorboardOptions(Options options) {
+    options.addOption(CliConstants.TENSORBOARD, false,
+        "Should we run TensorBoard"
+            + " for this job? By default it's disabled. " +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
+        "Specify resources of Tensorboard, by default it is "
+            + CliConstants.TENSORBOARD_DEFAULT_RESOURCES + ". " +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
+        "Specify Tensorboard docker image. When this is not "
+            + "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
+            + " as default. " +
+            CAN_BE_USED_WITH_TF_ONLY);
   }
 
   private void parseCommandLineAndGetRunJobParameters(String[] args)
       throws ParseException, IOException, YarnException {
     try {
-      // Do parsing
       GnuParser parser = new GnuParser();
       CommandLine cli = parser.parse(options, args);
-      ParametersHolder parametersHolder = createParametersHolder(cli);
-      parameters.updateParameters(parametersHolder, clientContext);
+      parametersHolder = createParametersHolder(cli);
+      parametersHolder.updateParameters(clientContext);
     } catch (ParseException e) {
       LOG.error("Exception in parse: {}", e.getMessage());
       printUsages();
       throw e;
     }
-
-    // Set default job dir / saved model dir, etc.
-    setDefaultDirs();
-
-    // replace patterns
-    replacePatternsInParameters();
   }
 
-  private ParametersHolder createParametersHolder(CommandLine cli) {
+  private ParametersHolder createParametersHolder(CommandLine cli)
+      throws ParseException, YarnException {
     String yamlConfigFile =
         cli.getOptionValue(CliConstants.YAML_CONFIG);
     if (yamlConfigFile != null) {
       YamlConfigFile yamlConfig = readYamlConfigFile(yamlConfigFile);
-      if (yamlConfig == null) {
-        throw new YamlParseException(String.format(
-            YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
-      } else if (yamlConfig.getConfigs() == null) {
-        throw new YamlParseException(String.format(YAML_PARSE_FAILED +
-            ", config section should be defined, but it cannot be found in " +
-            "YAML file '%s'!", yamlConfigFile));
-      }
+      checkYamlConfig(yamlConfigFile, yamlConfig);
       LOG.info("Using YAML configuration!");
       LOG.info("Using YAML configuration!");
-      return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig);
+      return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig,
+          Command.RUN_JOB);
     } else {
       LOG.info("Using CLI configuration!");
-      return ParametersHolder.createWithCmdLine(cli);
+      return ParametersHolder.createWithCmdLine(cli, Command.RUN_JOB);
+    }
+  }
+
+  private void checkYamlConfig(String yamlConfigFile,
+      YamlConfigFile yamlConfig) {
+    if (yamlConfig == null) {
+      throw new YamlParseException(String.format(
+          YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
+    } else if (yamlConfig.getConfigs() == null) {
+      throw new YamlParseException(String.format(YAML_PARSE_FAILED +
+          ", config section should be defined, but it cannot be found in " +
+          "YAML file '%s'!", yamlConfigFile));
     }
   }
 
@@ -256,34 +296,9 @@ public class RunJobCli extends AbstractCli {
         e);
   }
 
-  private void setDefaultDirs() throws IOException {
-    // Create directories if needed
-    String jobDir = parameters.getCheckpointPath();
-    if (null == jobDir) {
-      if (parameters.getNumWorkers() > 0) {
-        jobDir = clientContext.getRemoteDirectoryManager().getJobCheckpointDir(
-            parameters.getName(), true).toString();
-      } else {
-        // when #workers == 0, it means we only launch TB. In that case,
-        // point job dir to root dir so all job's metrics will be shown.
-        jobDir = clientContext.getRemoteDirectoryManager().getUserRootFolder()
-            .toString();
-      }
-      parameters.setCheckpointPath(jobDir);
-    }
-
-    if (parameters.getNumWorkers() > 0) {
-      // Only do this when #worker > 0
-      String savedModelDir = parameters.getSavedModelPath();
-      if (null == savedModelDir) {
-        savedModelDir = jobDir;
-        parameters.setSavedModelPath(savedModelDir);
-      }
-    }
-  }
-
-  private void storeJobInformation(String jobName, ApplicationId applicationId,
-      String[] args) throws IOException {
+  private void storeJobInformation(RunJobParameters parameters,
+      ApplicationId applicationId, String[] args) throws IOException {
+    String jobName = parameters.getName();
     Map<String, String> jobInfo = new HashMap<>();
     jobInfo.put(StorageKeyConstants.JOB_NAME, jobName);
     jobInfo.put(StorageKeyConstants.APPLICATION_ID, applicationId.toString());
@@ -316,8 +331,10 @@ public class RunJobCli extends AbstractCli {
     }
 
     parseCommandLineAndGetRunJobParameters(args);
-    ApplicationId applicationId = this.jobSubmitter.submitJob(parameters);
-    storeJobInformation(parameters.getName(), applicationId, args);
+    ApplicationId applicationId = jobSubmitter.submitJob(parametersHolder);
+    RunJobParameters parameters =
+        (RunJobParameters) parametersHolder.getParameters();
+    storeJobInformation(parameters, applicationId, args);
     if (parameters.isWaitJobFinish()) {
       this.jobMonitor.waitTrainingFinal(parameters.getName());
     }
@@ -332,6 +349,6 @@
 
   @VisibleForTesting
   public RunJobParameters getRunJobParameters() {
-    return parameters;
+    return (RunJobParameters) parametersHolder.getParameters();
   }
 }
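
Putting the RunJobCli changes together, a PyTorch submission differs from a TensorFlow one only in --framework and the absence of PS/TensorBoard options. An illustrative invocation (option names are assumed to mirror CliConstants; the jar name, image and paths are placeholders):

    yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
      --framework pytorch \
      --name pytorch-job-001 \
      --docker_image <pytorch docker image> \
      --input_path hdfs://default/dataset/cifar-10 \
      --worker_resources memory=4G,vcores=2 \
      --worker_launch_cmd "python cifar10_tutorial.py"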

+ 19 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/package-info.java

@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes that are related to the run job command.
+ */
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;

+ 54 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/PyTorchRole.java

@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+/**
+ * Enum to represent a PyTorch Role.
+ */
+public enum PyTorchRole implements Role {
+  PRIMARY_WORKER("master"),
+  WORKER("worker");
+
+  private String compName;
+
+  PyTorchRole(String compName) {
+    this.compName = compName;
+  }
+
+  public String getComponentName() {
+    return compName;
+  }
+
+  @Override
+  public String getName() {
+    return name();
+  }
+}

+ 25 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/Role.java

@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+/**
+ * Interface for a Role.
+ */
+public interface Role {
+  String getComponentName();
+  String getName();
+}
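
To make the contract concrete, the sketch below consumes role enums purely through the Role interface; the types are re-declared in simplified form under assumed names so the example compiles standalone, while the real ones live in org.apache.hadoop.yarn.submarine.common.api:

```java
/**
 * Standalone sketch: consuming role enums through the Role interface.
 * Simplified re-declarations of the types added in this patch.
 */
public class RoleSketch {

  interface Role {
    String getComponentName();
    String getName();
  }

  enum PyTorchRole implements Role {
    PRIMARY_WORKER("master"),
    WORKER("worker");

    private final String compName;

    PyTorchRole(String compName) {
      this.compName = compName;
    }

    @Override
    public String getComponentName() {
      return compName;
    }

    @Override
    public String getName() {
      return name();
    }
  }

  public static void main(String[] args) {
    // Framework-agnostic code sees only Role, not the concrete enum.
    for (Role role : PyTorchRole.values()) {
      System.out.println(role.getName() + " -> " + role.getComponentName());
    }
  }
}
```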

+ 58 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/Runtime.java

@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+import com.google.common.collect.Lists;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Represents the type of Runtime.
+ */
+public enum Runtime {
+  TONY(Constants.TONY), YARN_SERVICE(Constants.YARN_SERVICE);
+
+  private String value;
+
+  Runtime(String value) {
+    this.value = value;
+  }
+
+  public String getValue() {
+    return value;
+  }
+
+  public static Runtime parseByValue(String value) {
+    for (Runtime rt : Runtime.values()) {
+      if (rt.value.equalsIgnoreCase(value)) {
+        return rt;
+      }
+    }
+    return null;
+  }
+
+  public static String getValues() {
+    List<String> values = Lists.newArrayList(Runtime.values()).stream()
+        .map(rt -> rt.value).collect(Collectors.toList());
+    return String.join(",", values);
+  }
+
+  public static class Constants {
+    public static final String TONY = "tony";
+    public static final String YARN_SERVICE = "yarnservice";
+  }
+}
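
Because parseByValue matches case-insensitively and returns null for unknown input, callers are expected to pair it with getValues() when reporting bad values. A minimal sketch of that pattern, assuming the Runtime enum above is on the classpath (the wrapper class itself is illustrative):

```java
import org.apache.hadoop.yarn.submarine.common.api.Runtime;

/** Illustrative validation helper around the Runtime enum above. */
public final class RuntimeValidationSketch {

  private RuntimeValidationSketch() {
  }

  static Runtime requireRuntime(String value) {
    Runtime runtime = Runtime.parseByValue(value); // equalsIgnoreCase match
    if (runtime == null) {
      throw new IllegalArgumentException("Unknown runtime '" + value
          + "'. Valid values: " + Runtime.getValues());
    }
    return runtime;
  }

  public static void main(String[] args) {
    System.out.println(requireRuntime("TonY"));        // TONY
    System.out.println(requireRuntime("yarnservice")); // YARN_SERVICE
  }
}
```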

+ 11 - 2
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/TaskType.java → hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/TensorFlowRole.java

@@ -14,7 +14,10 @@

package org.apache.hadoop.yarn.submarine.common.api;

-public enum TaskType {
+/**
+ * Enum to represent a TensorFlow Role.
+ */
+public enum TensorFlowRole implements Role {
  PRIMARY_WORKER("master"),
  WORKER("worker"),
  PS("ps"),
@@ -22,11 +25,17 @@ public enum TaskType {

  private String compName;

-  TaskType(String compName) {
+  TensorFlowRole(String compName) {
    this.compName = compName;
  }

+  @Override
  public String getComponentName() {
    return compName;
  }
+
+  @Override
+  public String getName() {
+    return name();
+  }
}

+ 5 - 5
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/common/JobSubmitter.java

@@ -16,21 +16,21 @@ package org.apache.hadoop.yarn.submarine.runtimes.common;

import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;

import java.io.IOException;

/**
- * Submit job to cluster master
+ * Submits a job to the cluster master.
 */
public interface JobSubmitter {
  /**
-   * Submit job to cluster
+   * Submits a job to the cluster.
   * @param parameters run job parameters
-   * @return applicatioId when successfully submitted
+   * @return applicationId when successfully submitted
   * @throws YarnException for issues while contacting YARN daemons
   * @throws IOException for other issues.
   */
-  ApplicationId submitJob(RunJobParameters parameters)
+  ApplicationId submitJob(ParametersHolder parameters)
      throws IOException, YarnException;
}
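
For orientation, the smallest concrete implementation of the reworked interface is a stub like the one below; it mirrors the fixed ApplicationId returned by the Mockito mocks in this patch's tests and is a sketch, not a real runtime submitter:

```java
import java.io.IOException;

import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;

/**
 * Sketch: a no-op JobSubmitter against the new ParametersHolder-based
 * signature. Useful as a test stub; a real submitter would translate
 * the held parameters into a TonY or YARN service submission.
 */
public class NoOpJobSubmitter implements JobSubmitter {

  @Override
  public ApplicationId submitJob(ParametersHolder parameters)
      throws IOException, YarnException {
    // Pretend the job was accepted by the cluster.
    return ApplicationId.newInstance(1234L, 1);
  }
}
```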

+ 7 - 0
hadoop-submarine/hadoop-submarine-core/src/site/markdown/QuickStart.md

@@ -40,6 +40,10 @@ More details, please refer to

```$xslt
usage: job run
+
+ -framework <arg>             Framework to use.
+                              Valid values are: tensorflow, pytorch.
+                              The default framework is TensorFlow.
 -checkpoint_path <arg>       Training output directory of the job, could
                              be local or other FS directory. This
                              typically includes checkpoint files and
@@ -130,6 +134,7 @@ For submarine internal configuration, please create a `submarine.xml` which shou
#### Commandline
```
yarn jar path-to/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar job run \
+  --framework tensorflow \
  --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
  --env DOCKER_HADOOP_HDFS_HOME=/hadoop-current --name tf-job-001 \
  --docker_image <your-docker-image> \
@@ -163,6 +168,7 @@ See below screenshot:
```
yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
 --name tf-job-001 --docker_image <your-docker-image> \
+ --framework tensorflow \
 --input_path hdfs://default/dataset/cifar-10-data \
 --checkpoint_path hdfs://default/tmp/cifar-10-jobdir \
 --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
@@ -208,6 +214,7 @@ After that, you can run ```tensorboard --logdir=<checkpoint-path>``` to view Ten
yarn app -destroy tensorboard-service; \
yarn jar /tmp/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar \
  job run --name tensorboard-service --verbose --docker_image <your-docker-image> \
+  --framework tensorflow \
  --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
  --env DOCKER_HADOOP_HDFS_HOME=/hadoop-current \
  --num_workers 0 --tensorboard
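
The new --framework option shown above accepts tensorflow and pytorch and falls back to TensorFlow when absent. Submarine's real Framework type is not part of this excerpt, so the sketch below is hypothetical; it only shows how such an option can map onto an enum, reusing the value strings and the "Failed to parse Framework type" message that the tests in this patch assert on:

```java
import org.apache.commons.cli.ParseException;

/**
 * Hypothetical sketch of a --framework option mapping onto an enum.
 * Not Submarine's real Framework class; the values and error message
 * are taken from the help text and tests in this patch.
 */
public enum FrameworkSketch {
  TENSORFLOW("tensorflow"),
  PYTORCH("pytorch");

  private final String value;

  FrameworkSketch(String value) {
    this.value = value;
  }

  public String getValue() {
    return value;
  }

  public static FrameworkSketch parseByValue(String value)
      throws ParseException {
    for (FrameworkSketch framework : values()) {
      if (framework.value.equalsIgnoreCase(value)) {
        return framework;
      }
    }
    throw new ParseException("Failed to parse Framework type: " + value);
  }
}
```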

+ 0 - 226
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java

@@ -1,226 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * <p>
- * http://www.apache.org/licenses/LICENSE-2.0
- * <p>
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.hadoop.yarn.submarine.client.cli;
-
-import org.apache.commons.cli.ParseException;
-import org.apache.hadoop.yarn.api.records.ApplicationId;
-import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.MockClientContext;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
-import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
-import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
-import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
-import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
-import org.apache.hadoop.yarn.util.resource.Resources;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-import java.io.IOException;
-
-import static org.junit.Assert.assertEquals;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-public class TestRunJobCliParsing {
-
-  @Before
-  public void before() {
-    SubmarineLogs.verboseOff();
-  }
-
-  @Test
-  public void testPrintHelp() {
-    MockClientContext mockClientContext = new MockClientContext();
-    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
-    JobMonitor mockJobMonitor = mock(JobMonitor.class);
-    RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
-        mockJobMonitor);
-    runJobCli.printUsages();
-  }
-
-  static MockClientContext getMockClientContext()
-      throws IOException, YarnException {
-    MockClientContext mockClientContext = new MockClientContext();
-    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
-    when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
-        ApplicationId.newInstance(1234L, 1));
-    JobMonitor mockJobMonitor = mock(JobMonitor.class);
-    SubmarineStorage storage = mock(SubmarineStorage.class);
-    RuntimeFactory rtFactory = mock(RuntimeFactory.class);
-
-    when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
-    when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
-    when(rtFactory.getSubmarineStorage()).thenReturn(storage);
-
-    mockClientContext.setRuntimeFactory(rtFactory);
-    return mockClientContext;
-  }
-
-  @Test
-  public void testBasicRunJobForDistributedTraining() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    runJobCli.run(
-        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-            "--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
-            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
-            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
-            "--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
-            "--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
-            "--principal", "user/_HOST@domain.com", "--distribute_keytab",
-            "--verbose" });
-
-    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
-
-    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
-    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
-    assertEquals(jobRunParameters.getNumPS(), 2);
-    assertEquals(jobRunParameters.getPSLaunchCmd(), "python run-ps.py");
-    assertEquals(Resources.createResource(4096, 4),
-        jobRunParameters.getPsResource());
-    assertEquals(jobRunParameters.getWorkerLaunchCmd(),
-        "python run-job.py");
-    assertEquals(Resources.createResource(2048, 2),
-        jobRunParameters.getWorkerResource());
-    assertEquals(jobRunParameters.getDockerImageName(),
-        "tf-docker:1.1.0");
-    assertEquals(jobRunParameters.getKeytab(),
-        "/keytab/path");
-    assertEquals(jobRunParameters.getPrincipal(),
-        "user/_HOST@domain.com");
-    Assert.assertTrue(jobRunParameters.isDistributeKeytab());
-    Assert.assertTrue(SubmarineLogs.isVerbose());
-  }
-
-  @Test
-  public void testBasicRunJobForSingleNodeTraining() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    runJobCli.run(
-        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-            "--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
-            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
-            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
-            "true", "--verbose", "--wait_job_finish" });
-
-    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
-
-    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
-    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
-    assertEquals(jobRunParameters.getNumWorkers(), 1);
-    assertEquals(jobRunParameters.getWorkerLaunchCmd(),
-        "python run-job.py");
-    assertEquals(Resources.createResource(4096, 2),
-        jobRunParameters.getWorkerResource());
-    Assert.assertTrue(SubmarineLogs.isVerbose());
-    Assert.assertTrue(jobRunParameters.isWaitJobFinish());
-  }
-
-  @Test
-  public void testNoInputPathOptionSpecified() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\" is absent";
-    String actualMessage = "";
-    try {
-      runJobCli.run(
-          new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-              "--checkpoint_path", "hdfs://output",
-              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
-              "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
-              "true", "--verbose", "--wait_job_finish"});
-    } catch (ParseException e) {
-      actualMessage = e.getMessage();
-      e.printStackTrace();
-    }
-    assertEquals(expectedErrorMessage, actualMessage);
-  }
-
-  /**
-   * when only run tensorboard, input_path is not needed
-   * */
-  @Test
-  public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    boolean success = true;
-    try {
-      runJobCli.run(
-          new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-              "--num_workers", "0", "--tensorboard", "--verbose",
-              "--tensorboard_resources", "memory=2G,vcores=2",
-              "--tensorboard_docker_image", "tb_docker_image:001"});
-    } catch (ParseException e) {
-      success = false;
-    }
-    Assert.assertTrue(success);
-  }
-
-  @Test
-  public void testJobWithoutName() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    String expectedErrorMessage =
-        "--" + CliConstants.NAME + " is absent";
-    String actualMessage = "";
-    try {
-      runJobCli.run(
-          new String[]{"--docker_image", "tf-docker:1.1.0",
-              "--num_workers", "0", "--tensorboard", "--verbose",
-              "--tensorboard_resources", "memory=2G,vcores=2",
-              "--tensorboard_docker_image", "tb_docker_image:001"});
-    } catch (ParseException e) {
-      actualMessage = e.getMessage();
-      e.printStackTrace();
-    }
-    assertEquals(expectedErrorMessage, actualMessage);
-  }
-
-  @Test
-  public void testLaunchCommandPatternReplace() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    runJobCli.run(
-        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-            "--input_path", "hdfs://input", "--checkpoint_path",
-            "hdfs://output",
-            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
-            "python run-job.py --input=%input_path% " +
-                "--model_dir=%checkpoint_path% " +
-                "--export_dir=%saved_model_path%/savedmodel",
-            "--worker_resources", "memory=2048,vcores=2", "--ps_resources",
-            "memory=4096,vcores=4", "--tensorboard", "true", "--ps_launch_cmd",
-            "python run-ps.py --input=%input_path% " +
-                "--model_dir=%checkpoint_path%/model",
-            "--verbose" });
-
-    assertEquals(
-        "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
-            + "--export_dir=hdfs://output/savedmodel",
-        runJobCli.getRunJobParameters().getWorkerLaunchCmd());
-    assertEquals(
-        "python run-ps.py --input=hdfs://input --model_dir=hdfs://output/model",
-        runJobCli.getRunJobParameters().getPSLaunchCmd());
-  }
-}

+ 6 - 5
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/YamlConfigTestUtils.java

@@ -17,7 +17,7 @@
package org.apache.hadoop.yarn.submarine.client.cli;

import org.apache.commons.io.FileUtils;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
import org.yaml.snakeyaml.Yaml;
import org.yaml.snakeyaml.constructor.Constructor;
@@ -33,13 +33,13 @@ public final class YamlConfigTestUtils {

  private YamlConfigTestUtils() {}

-  static void deleteFile(File file) {
+  public static void deleteFile(File file) {
    if (file != null) {
      file.delete();
    }
  }

-  static YamlConfigFile readYamlConfigFile(String filename) {
+  public static YamlConfigFile readYamlConfigFile(String filename) {
    Constructor constructor = new Constructor(YamlConfigFile.class);
    constructor.setPropertyUtils(new UnderscoreConverterPropertyUtils());
    Yaml yaml = new Yaml(constructor);
@@ -49,7 +49,8 @@ public final class YamlConfigTestUtils {
    return yaml.loadAs(inputStream, YamlConfigFile.class);
  }

-  static File createTempFileWithContents(String filename) throws IOException {
+  public static File createTempFileWithContents(String filename)
+      throws IOException {
    InputStream inputStream = YamlConfigTestUtils.class
        .getClassLoader()
        .getResourceAsStream(filename);
@@ -58,7 +59,7 @@ public final class YamlConfigTestUtils {
    return targetFile;
  }

-  static File createEmptyTempFile() throws IOException {
+  public static File createEmptyTempFile() throws IOException {
    return File.createTempFile("test", ".yaml");
  }


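With the helpers widened to public, tests outside the cli package can drive the YAML fixtures directly. A usage sketch, assuming the fixture below (one of the resources referenced by the tests in this patch) is on the test classpath:

```java
import java.io.File;

import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;

/**
 * Sketch: using the now-public YamlConfigTestUtils helpers from another
 * test package. The fixture name is one referenced by tests in this
 * patch; any classpath YAML resource works the same way.
 */
public class YamlFixtureUsageSketch {

  public static void main(String[] args) throws Exception {
    File config = YamlConfigTestUtils.createTempFileWithContents(
        "runjob-common-yaml/some-sections-missing.yaml");
    try {
      YamlConfigFile parsed = YamlConfigTestUtils.readYamlConfigFile(
          "runjob-common-yaml/some-sections-missing.yaml");
      System.out.println("Has configs section: "
          + (parsed.getConfigs() != null));
    } finally {
      YamlConfigTestUtils.deleteFile(config); // null-safe per the util
    }
  }
}
```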
+ 129 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingCommon.java

@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import org.apache.commons.cli.MissingArgumentException;
+import org.apache.commons.cli.ParseException;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
+import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import java.io.IOException;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * This class contains test methods for the functionality of the Submarine
+ * run job command that is common to the TF and PyTorch frameworks.
+ */
+public class TestRunJobCliParsingCommon {
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  public static MockClientContext getMockClientContext()
+      throws IOException, YarnException {
+    MockClientContext mockClientContext = new MockClientContext();
+    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
+    when(mockJobSubmitter.submitJob(any(ParametersHolder.class)))
+        .thenReturn(ApplicationId.newInstance(1235L, 1));
+
+    JobMonitor mockJobMonitor = mock(JobMonitor.class);
+    SubmarineStorage storage = mock(SubmarineStorage.class);
+    RuntimeFactory rtFactory = mock(RuntimeFactory.class);
+
+    when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
+    when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
+    when(rtFactory.getSubmarineStorage()).thenReturn(storage);
+
+    mockClientContext.setRuntimeFactory(rtFactory);
+    return mockClientContext;
+  }
+
+  @Test
+  public void testAbsentFrameworkFallsBackToTensorFlow() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    runJobCli.run(
+        new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+            "true", "--verbose", "--wait_job_finish"});
+    RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
+    assertTrue("Default Framework should be TensorFlow!",
+        runJobParameters instanceof TensorFlowRunJobParameters);
+  }
+
+  @Test
+  public void testEmptyFrameworkOption() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(MissingArgumentException.class);
+    expectedException.expectMessage("Missing argument for option: framework");
+
+    runJobCli.run(
+        new String[]{"--framework", "--name", "my-job",
+            "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+            "true", "--verbose", "--wait_job_finish"});
+  }
+
+  @Test
+  public void testInvalidFrameworkOption() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("Failed to parse Framework type");
+
+    runJobCli.run(
+        new String[]{"--framework", "bla", "--name", "my-job",
+            "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+            "true", "--verbose", "--wait_job_finish"});
+  }
+}

+ 252 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingCommonYaml.java

@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import org.apache.hadoop.yarn.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * This class contains test methods for the YAML parsing functionality of
+ * the Submarine run job command that is common to TF and PyTorch.
+ */
+public class TestRunJobCliParsingCommonYaml {
+  private static final String DIR_NAME = "runjob-common-yaml";
+  private static final String PYTORCH_DIR = "runjob-pytorch-yaml";
+  private File yamlConfig;
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @After
+  public void after() {
+    YamlConfigTestUtils.deleteFile(yamlConfig);
+  }
+
+  @BeforeClass
+  public static void configureResourceTypes() {
+    List<ResourceTypeInfo> resTypes = new ArrayList<>(
+        ResourceUtils.getResourcesTypeInfo());
+    resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
+    ResourceUtils.reinitializeResources(resTypes);
+  }
+
+  @Rule
+  public ExpectedException exception = ExpectedException.none();
+
+  @Test
+  public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        PYTORCH_DIR + "/valid-config.yaml");
+    String[] args = new String[] {"--name", "my-job",
+        "--docker_image", "tf-docker:1.1.0",
+        "-f", yamlConfig.getAbsolutePath() };
+
+    exception.expect(YarnException.class);
+    exception.expectMessage("defined both with YAML config and with " +
+        "CLI argument");
+
+    runJobCli.run(args);
+  }
+
+  @Test
+  public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
+      throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        PYTORCH_DIR + "/valid-config.yaml");
+    String[] args = new String[] {"--name", "my-job",
+        "--quicklink", "AAA=http://master-0:8321",
+        "--quicklink", "BBB=http://worker-0:1234",
+        "-f", yamlConfig.getAbsolutePath()};
+
+    exception.expect(YarnException.class);
+    exception.expectMessage("defined both with YAML config and with " +
+        "CLI argument");
+
+    runJobCli.run(args);
+  }
+
+  @Test
+  public void testFalseValuesForBooleanFields() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/test-false-values.yaml");
+    runJobCli.run(
+        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertFalse(jobRunParameters.isDistributeKeytab());
+    assertFalse(jobRunParameters.isWaitJobFinish());
+    assertFalse(tensorFlowParams.isTensorboardEnabled());
+  }
+
+  @Test
+  public void testWrongIndentation() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/wrong-indentation.yaml");
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("Failed to parse YAML config, details:");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testWrongFilename() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    exception.expect(YamlParseException.class);
+    runJobCli.run(
+        new String[]{"-f", "not-existing", "--verbose"});
+  }
+
+  @Test
+  public void testEmptyFile() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
+
+    exception.expect(YamlParseException.class);
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testNotExistingFile() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("file does not exist");
+    runJobCli.run(
+        new String[]{"-f", "blabla", "--verbose"});
+  }
+
+  @Test
+  public void testWrongPropertyName() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/wrong-property-name.yaml");
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("Failed to parse YAML config, details:");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testMissingConfigsSection() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/missing-configs.yaml");
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("config section should be defined, " +
+        "but it cannot be found");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testMissingSectionsShouldBeParsed() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/some-sections-missing.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testAbsentFramework() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/missing-framework.yaml");
+
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testEmptyFramework() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/empty-framework.yaml");
+
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testInvalidFramework() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/invalid-framework.yaml");
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("framework should is defined, " +
+        "but it has an invalid value");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+}

+ 192 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingParameterized.java

@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.cli.ParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+/**
+ * This class contains test methods for the CLI parsing functionality of
+ * the Submarine run job command that is common to TF and PyTorch.
+ */
+@RunWith(Parameterized.class)
+public class TestRunJobCliParsingParameterized {
+
+  private final Framework framework;
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  @Parameterized.Parameters
+  public static Collection<Object[]> data() {
+    Collection<Object[]> params = new ArrayList<>();
+    params.add(new Object[]{Framework.TENSORFLOW });
+    params.add(new Object[]{Framework.PYTORCH });
+    return params;
+  }
+
+  public TestRunJobCliParsingParameterized(Framework framework) {
+    this.framework = framework;
+  }
+
+  private String getFrameworkName() {
+    return framework.getValue();
+  }
+
+  @Test
+  public void testPrintHelp() {
+    MockClientContext mockClientContext = new MockClientContext();
+    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
+    JobMonitor mockJobMonitor = mock(JobMonitor.class);
+    RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
+        mockJobMonitor);
+    runJobCli.printUsages();
+  }
+
+  @Test
+  public void testNoInputPathOptionSpecified() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\"" +
+        " is absent";
+    String actualMessage = "";
+    try {
+      runJobCli.run(
+          new String[]{"--framework", getFrameworkName(),
+              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+              "--checkpoint_path", "hdfs://output",
+              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+              "--worker_resources", "memory=4g,vcores=2", "--verbose",
+              "--wait_job_finish"});
+    } catch (ParseException e) {
+      actualMessage = e.getMessage();
+      e.printStackTrace();
+    }
+    assertEquals(expectedErrorMessage, actualMessage);
+  }
+
+  @Test
+  public void testJobWithoutName() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    String expectedErrorMessage =
+        "--" + CliConstants.NAME + " is absent";
+    String actualMessage = "";
+    try {
+      runJobCli.run(
+          new String[]{"--framework", getFrameworkName(),
+              "--docker_image", "tf-docker:1.1.0",
+              "--num_workers", "0", "--verbose"});
+    } catch (ParseException e) {
+      actualMessage = e.getMessage();
+      e.printStackTrace();
+    }
+    assertEquals(expectedErrorMessage, actualMessage);
+  }
+
+  @Test
+  public void testLaunchCommandPatternReplace() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    List<String> parameters = Lists.newArrayList("--framework",
+        getFrameworkName(),
+        "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+        "--input_path", "hdfs://input", "--checkpoint_path",
+        "hdfs://output",
+        "--num_workers", "3",
+        "--worker_launch_cmd", "python run-job.py --input=%input_path% " +
+            "--model_dir=%checkpoint_path% " +
+            "--export_dir=%saved_model_path%/savedmodel",
+        "--worker_resources", "memory=2048,vcores=2");
+
+    if (framework == Framework.TENSORFLOW) {
+      parameters.addAll(Lists.newArrayList(
+          "--ps_resources", "memory=4096,vcores=4",
+          "--ps_launch_cmd", "python run-ps.py --input=%input_path% " +
+              "--model_dir=%checkpoint_path%/model",
+          "--verbose"));
+    }
+
+    runJobCli.run(parameters.toArray(new String[0]));
+
+    RunJobParameters runJobParameters = checkExpectedFrameworkParams(runJobCli);
+
+    if (framework == Framework.TENSORFLOW) {
+      TensorFlowRunJobParameters tensorFlowParams =
+          (TensorFlowRunJobParameters) runJobParameters;
+      assertEquals(
+          "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+              + "--export_dir=hdfs://output/savedmodel",
+          tensorFlowParams.getWorkerLaunchCmd());
+      assertEquals(
+          "python run-ps.py --input=hdfs://input " +
+              "--model_dir=hdfs://output/model",
+          tensorFlowParams.getPSLaunchCmd());
+    } else if (framework == Framework.PYTORCH) {
+      PyTorchRunJobParameters pyTorchParameters =
+          (PyTorchRunJobParameters) runJobParameters;
+      assertEquals(
+          "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+              + "--export_dir=hdfs://output/savedmodel",
+          pyTorchParameters.getWorkerLaunchCmd());
+    }
+  }
+
+  private RunJobParameters checkExpectedFrameworkParams(RunJobCli runJobCli) {
+    RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
+
+    if (framework == Framework.TENSORFLOW) {
+      assertTrue(RunJobParameters.class + " must be an instance of " +
+              TensorFlowRunJobParameters.class,
+          runJobParameters instanceof TensorFlowRunJobParameters);
+    } else if (framework == Framework.PYTORCH) {
+      assertTrue(RunJobParameters.class + " must be an instance of " +
+              PyTorchRunJobParameters.class,
+          runJobParameters instanceof PyTorchRunJobParameters);
+    }
+    return runJobParameters;
+  }
+
+}

+ 209 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/pytorch/TestRunJobCliParsingPyTorch.java

@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
+
+import org.apache.commons.cli.ParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.util.resource.Resources;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Test class that verifies the correctness of PyTorch
+ * CLI configuration parsing.
+ */
+public class TestRunJobCliParsingPyTorch {
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  @Test
+  public void testBasicRunJobForSingleNodeTraining() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--verbose",
+            "--wait_job_finish" });
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    assertTrue(RunJobParameters.class +
+            " must be an instance of " +
+            PyTorchRunJobParameters.class,
+        jobRunParameters instanceof PyTorchRunJobParameters);
+    PyTorchRunJobParameters pyTorchParams =
+        (PyTorchRunJobParameters) jobRunParameters;
+
+    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
+    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
+    assertEquals(pyTorchParams.getNumWorkers(), 1);
+    assertEquals(pyTorchParams.getWorkerLaunchCmd(),
+        "python run-job.py");
+    assertEquals(Resources.createResource(4096, 2),
+        pyTorchParams.getWorkerResource());
+    assertTrue(SubmarineLogs.isVerbose());
+    assertTrue(jobRunParameters.isWaitJobFinish());
+  }
+
+  @Test
+  public void testNumPSCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--num_ps", "2" });
+  }
+
+  @Test
+  public void testPSResourcesCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--ps_resources", "memory=2048M,vcores=2" });
+  }
+
+  @Test
+  public void testPSDockerImageCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--ps_docker_image", "psDockerImage" });
+  }
+
+  @Test
+  public void testPSLaunchCommandCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--ps_launch_cmd", "psLaunchCommand" });
+  }
+
+  @Test
+  public void testTensorboardCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--tensorboard" });
+  }
+
+  @Test
+  public void testTensorboardResourcesCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--tensorboard_resources", "memory=2048M,vcores=2" });
+  }
+
+  @Test
+  public void testTensorboardDockerImageCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--tensorboard_docker_image", "TBDockerImage" });
+  }
+
+}

+ 225 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/pytorch/TestRunJobCliParsingPyTorchYaml.java

@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.yarn.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
+import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
+import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+/**
+ * Test class that verifies the correctness of PyTorch
+ * YAML configuration parsing.
+ */
+public class TestRunJobCliParsingPyTorchYaml {
+  private static final String OVERRIDDEN_PREFIX = "overridden_";
+  private static final String DIR_NAME = "runjob-pytorch-yaml";
+  private File yamlConfig;
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @After
+  public void after() {
+    YamlConfigTestUtils.deleteFile(yamlConfig);
+  }
+
+  @BeforeClass
+  public static void configureResourceTypes() {
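+    // Register the GPU resource type with YARN so that "gpu=2" style
+    // resource strings in the test YAML fixtures can be parsed.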
+    List<ResourceTypeInfo> resTypes = new ArrayList<>(
+        ResourceUtils.getResourcesTypeInfo());
+    resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
+    ResourceUtils.reinitializeResources(resTypes);
+  }
+
+  @Rule
+  public ExpectedException exception = ExpectedException.none();
+
+  private void verifyBasicConfigValues(RunJobParameters jobRunParameters) {
+    verifyBasicConfigValues(jobRunParameters,
+        ImmutableList.of("env1=env1Value", "env2=env2Value"));
+  }
+
+  private void verifyBasicConfigValues(RunJobParameters jobRunParameters,
+      List<String> expectedEnvs) {
+    assertEquals("testInputPath", jobRunParameters.getInputPath());
+    assertEquals("testCheckpointPath", jobRunParameters.getCheckpointPath());
+    assertEquals("testDockerImage", jobRunParameters.getDockerImageName());
+
+    assertNotNull(jobRunParameters.getLocalizations());
+    assertEquals(2, jobRunParameters.getLocalizations().size());
+
+    assertNotNull(jobRunParameters.getQuicklinks());
+    assertEquals(2, jobRunParameters.getQuicklinks().size());
+
+    assertTrue(SubmarineLogs.isVerbose());
+    assertTrue(jobRunParameters.isWaitJobFinish());
+
+    for (String env : expectedEnvs) {
+      assertTrue(String.format(
+          "%s should be in env list of jobRunParameters!", env),
+          jobRunParameters.getEnvars().contains(env));
+    }
+  }
+
+  private void verifyWorkerValues(RunJobParameters jobRunParameters,
+      String prefix) {
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            PyTorchRunJobParameters.class,
+        jobRunParameters instanceof PyTorchRunJobParameters);
+    PyTorchRunJobParameters pyTorchParams =
+        (PyTorchRunJobParameters) jobRunParameters;
+
+    assertEquals(3, pyTorchParams.getNumWorkers());
+    assertEquals(prefix + "testLaunchCmdWorker",
+        pyTorchParams.getWorkerLaunchCmd());
+    assertEquals(prefix + "testDockerImageWorker",
+        pyTorchParams.getWorkerDockerImage());
+    assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
+        ImmutableMap.<String, String> builder()
+            .put(ResourceInformation.GPU_URI, "2").build()),
+        pyTorchParams.getWorkerResource());
+  }
+
+  private void verifySecurityValues(RunJobParameters jobRunParameters) {
+    assertEquals("keytabPath", jobRunParameters.getKeytab());
+    assertEquals("testPrincipal", jobRunParameters.getPrincipal());
+    assertTrue(jobRunParameters.isDistributeKeytab());
+  }
+
+  @Test
+  public void testValidYamlParsing() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/valid-config.yaml");
+    runJobCli.run(
+        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    verifyBasicConfigValues(jobRunParameters);
+    verifyWorkerValues(jobRunParameters, "");
+    verifySecurityValues(jobRunParameters);
+  }
+
+  @Test
+  public void testRoleOverrides() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/valid-config-with-overrides.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    verifyBasicConfigValues(jobRunParameters);
+    verifyWorkerValues(jobRunParameters, OVERRIDDEN_PREFIX);
+    verifySecurityValues(jobRunParameters);
+  }
+
+  @Test
+  public void testMissingPrincipalUnderSecuritySection() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/security-principal-is-missing.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    verifyBasicConfigValues(jobRunParameters);
+    verifyWorkerValues(jobRunParameters, "");
+
+    //Verify security values
+    assertEquals("keytabPath", jobRunParameters.getKeytab());
+    assertNull("Principal should be null!", jobRunParameters.getPrincipal());
+    assertTrue(jobRunParameters.isDistributeKeytab());
+  }
+
+  @Test
+  public void testMissingEnvs() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/envs-are-missing.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    verifyBasicConfigValues(jobRunParameters, ImmutableList.of());
+    verifyWorkerValues(jobRunParameters, "");
+    verifySecurityValues(jobRunParameters);
+  }
+
+  @Test
+  public void testInvalidConfigPsSectionIsDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("PS section should not be defined " +
+        "when PyTorch is the selected framework");
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/invalid-config-ps-section.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testInvalidConfigTensorboardSectionIsDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("TensorBoard section should not be defined " +
+        "when PyTorch is the selected framework");
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/invalid-config-tensorboard-section.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+}
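The instanceof-check-plus-downcast pattern above recurs in several verifier
methods; a small generic helper could factor it out. A sketch only (the helper
name is hypothetical, not part of this patch):

  private static <T extends RunJobParameters> T expectInstanceOf(
      RunJobParameters params, Class<T> clazz) {
    assertTrue(params + " must be an instance of " + clazz.getName(),
        clazz.isInstance(params));
    return clazz.cast(params);
  }

  // Usage: PyTorchRunJobParameters pyTorchParams =
  //     expectInstanceOf(jobRunParameters, PyTorchRunJobParameters.class);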

+ 170 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlow.java

@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
+
+import org.apache.commons.cli.ParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.util.resource.Resources;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Test class that verifies the correctness of TensorFlow
+ * CLI configuration parsing.
+ */
+public class TestRunJobCliParsingTensorFlow {
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  @Test
+  public void testNoInputPathOptionSpecified() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH +
+        "\" is absent";
+    String actualMessage = "";
+    try {
+      runJobCli.run(
+          new String[]{"--framework", "tensorflow",
+              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+              "--checkpoint_path", "hdfs://output",
+              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+              "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+              "true", "--verbose", "--wait_job_finish"});
+    } catch (ParseException e) {
+      actualMessage = e.getMessage();
+      e.printStackTrace();
+    }
+    assertEquals(expectedErrorMessage, actualMessage);
+  }
+
+  @Test
+  public void testBasicRunJobForDistributedTraining() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    assertFalse(SubmarineLogs.isVerbose());
+
+    runJobCli.run(
+        new String[] { "--framework", "tensorflow",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
+            "--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
+            "--principal", "user/_HOST@domain.com", "--distribute_keytab",
+            "--verbose" });
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    assertTrue(RunJobParameters.class +
+        " must be an instance of " +
+        TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
+    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
+    assertEquals(tensorFlowParams.getNumPS(), 2);
+    assertEquals(tensorFlowParams.getPSLaunchCmd(), "python run-ps.py");
+    assertEquals(Resources.createResource(4096, 4),
+        tensorFlowParams.getPsResource());
+    assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
+        "python run-job.py");
+    assertEquals(Resources.createResource(2048, 2),
+        tensorFlowParams.getWorkerResource());
+    assertEquals(jobRunParameters.getDockerImageName(),
+        "tf-docker:1.1.0");
+    assertEquals(jobRunParameters.getKeytab(),
+        "/keytab/path");
+    assertEquals(jobRunParameters.getPrincipal(),
+        "user/_HOST@domain.com");
+    assertTrue(jobRunParameters.isDistributeKeytab());
+    assertTrue(SubmarineLogs.isVerbose());
+  }
+
+  @Test
+  public void testBasicRunJobForSingleNodeTraining() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    runJobCli.run(
+        new String[] { "--framework", "tensorflow",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+            "true", "--verbose", "--wait_job_finish" });
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    assertTrue(RunJobParameters.class +
+            " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
+    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
+    assertEquals(tensorFlowParams.getNumWorkers(), 1);
+    assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
+        "python run-job.py");
+    assertEquals(Resources.createResource(4096, 2),
+        tensorFlowParams.getWorkerResource());
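+    // (The "memory=4g" CLI value above is normalized to megabytes: 4096.)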
+    assertTrue(SubmarineLogs.isVerbose());
+    assertTrue(jobRunParameters.isWaitJobFinish());
+  }
+
+  /**
+   * When only TensorBoard is run, input_path is not needed.
+   */
+  @Test
+  public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    boolean success = true;
+    try {
+      runJobCli.run(
+          new String[]{"--framework", "tensorflow",
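+              // Zero workers plus --tensorboard: a TensorBoard-only job,
+              // so --input_path may be omitted (see the Javadoc above).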
+              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+              "--num_workers", "0", "--tensorboard", "--verbose",
+              "--tensorboard_resources", "memory=2G,vcores=2",
+              "--tensorboard_docker_image", "tb_docker_image:001"});
+    } catch (ParseException e) {
+      success = false;
+    }
+    assertTrue(success);
+  }
+}

+ 45 - 159
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsingYaml.java → hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYaml.java

@@ -15,16 +15,17 @@
  */


-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.apache.hadoop.yarn.api.records.ResourceInformation;
 import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
-import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.util.resource.ResourceUtils;
 import org.junit.After;
@@ -39,19 +40,18 @@ import java.io.File;
 import java.util.ArrayList;
 import java.util.List;
 
-import static org.apache.hadoop.yarn.submarine.client.cli.TestRunJobCliParsing.getMockClientContext;
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
 /**
- * Test class that verifies the correctness of YAML configuration parsing.
+ * Test class that verifies the correctness of TF YAML configuration parsing.
  */
-public class TestRunJobCliParsingYaml {
+public class TestRunJobCliParsingTensorFlowYaml {
   private static final String OVERRIDDEN_PREFIX = "overridden_";
-  private static final String DIR_NAME = "runjobcliparsing";
+  private static final String DIR_NAME = "runjob-tensorflow-yaml";
   private File yamlConfig;
 
   @Before
@@ -104,27 +104,39 @@ public class TestRunJobCliParsingYaml {
 
   private void verifyPsValues(RunJobParameters jobRunParameters,
       String prefix) {
-    assertEquals(4, jobRunParameters.getNumPS());
-    assertEquals(prefix + "testLaunchCmdPs", jobRunParameters.getPSLaunchCmd());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(4, tensorFlowParams.getNumPS());
+    assertEquals(prefix + "testLaunchCmdPs", tensorFlowParams.getPSLaunchCmd());
     assertEquals(prefix + "testDockerImagePs",
-        jobRunParameters.getPsDockerImage());
+        tensorFlowParams.getPsDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(20500L, 34,
         ImmutableMap.<String, String> builder()
             .put(ResourceInformation.GPU_URI, "4").build()),
-        jobRunParameters.getPsResource());
+        tensorFlowParams.getPsResource());
   }
 
   private void verifyWorkerValues(RunJobParameters jobRunParameters,
       String prefix) {
-    assertEquals(3, jobRunParameters.getNumWorkers());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(3, tensorFlowParams.getNumWorkers());
     assertEquals(prefix + "testLaunchCmdWorker",
-        jobRunParameters.getWorkerLaunchCmd());
+        tensorFlowParams.getWorkerLaunchCmd());
     assertEquals(prefix + "testDockerImageWorker",
-        jobRunParameters.getWorkerDockerImage());
+        tensorFlowParams.getWorkerDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
         ImmutableMap.<String, String> builder()
             .put(ResourceInformation.GPU_URI, "2").build()),
-        jobRunParameters.getWorkerResource());
+        tensorFlowParams.getWorkerResource());
   }
 
   private void verifySecurityValues(RunJobParameters jobRunParameters) {
@@ -134,13 +146,19 @@ public class TestRunJobCliParsingYaml {
   }
 
   private void verifyTensorboardValues(RunJobParameters jobRunParameters) {
-    assertTrue(jobRunParameters.isTensorboardEnabled());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertTrue(tensorFlowParams.isTensorboardEnabled());
     assertEquals("tensorboardDockerImage",
-        jobRunParameters.getTensorboardDockerImage());
+        tensorFlowParams.getTensorboardDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
         ImmutableMap.<String, String> builder()
             .put(ResourceInformation.GPU_URI, "3").build()),
-        jobRunParameters.getTensorboardResource());
+        tensorFlowParams.getTensorboardResource());
   }
 
   @Test
@@ -161,44 +179,6 @@ public class TestRunJobCliParsingYaml {
     verifyTensorboardValues(jobRunParameters);
   }
 
-  @Test
-  public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/valid-config.yaml");
-    String[] args = new String[] {"--name", "my-job",
-        "--docker_image", "tf-docker:1.1.0",
-        "-f", yamlConfig.getAbsolutePath() };
-
-    exception.expect(YarnException.class);
-    exception.expectMessage("defined both with YAML config and with " +
-        "CLI argument");
-
-    runJobCli.run(args);
-  }
-
-  @Test
-  public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
-      throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/valid-config.yaml");
-    String[] args = new String[] {"--name", "my-job",
-        "--quicklink", "AAA=http://master-0:8321",
-        "--quicklink", "BBB=http://worker-0:1234",
-        "-f", yamlConfig.getAbsolutePath()};
-
-    exception.expect(YarnException.class);
-    exception.expectMessage("defined both with YAML config and with " +
-        "CLI argument");
-
-    runJobCli.run(args);
-  }
-
   @Test
   public void testRoleOverrides() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@@ -217,104 +197,6 @@ public class TestRunJobCliParsingYaml {
     verifyTensorboardValues(jobRunParameters);
   }
 
-  @Test
-  public void testFalseValuesForBooleanFields() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/test-false-values.yaml");
-    runJobCli.run(
-        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
-
-    assertFalse(jobRunParameters.isDistributeKeytab());
-    assertFalse(jobRunParameters.isWaitJobFinish());
-    assertFalse(jobRunParameters.isTensorboardEnabled());
-  }
-
-  @Test
-  public void testWrongIndentation() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/wrong-indentation.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("Failed to parse YAML config, details:");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testWrongFilename() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    exception.expect(YamlParseException.class);
-    runJobCli.run(
-        new String[]{"-f", "not-existing", "--verbose"});
-  }
-
-  @Test
-  public void testEmptyFile() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
-
-    exception.expect(YamlParseException.class);
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testNotExistingFile() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("file does not exist");
-    runJobCli.run(
-        new String[]{"-f", "blabla", "--verbose"});
-  }
-
-  @Test
-  public void testWrongPropertyName() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/wrong-property-name.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("Failed to parse YAML config, details:");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testMissingConfigsSection() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/missing-configs.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("config section should be defined, " +
-        "but it cannot be found");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testMissingSectionsShouldParsed() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/some-sections-missing.yaml");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
   @Test
   public void testMissingPrincipalUnderSecuritySection() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@@ -346,18 +228,22 @@ public class TestRunJobCliParsingYaml {
         new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
 
     RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+
     verifyBasicConfigValues(jobRunParameters);
     verifyPsValues(jobRunParameters, "");
     verifyWorkerValues(jobRunParameters, "");
     verifySecurityValues(jobRunParameters);
 
-    assertTrue(jobRunParameters.isTensorboardEnabled());
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertTrue(tensorFlowParams.isTensorboardEnabled());
     assertNull("tensorboardDockerImage should be null!",
-        jobRunParameters.getTensorboardDockerImage());
+        tensorFlowParams.getTensorboardDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
         ImmutableMap.<String, String> builder()
             .put(ResourceInformation.GPU_URI, "3").build()),
-        jobRunParameters.getTensorboardResource());
+        tensorFlowParams.getTensorboardResource());
   }
 
   @Test

+ 8 - 9
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsingYamlStandalone.java → hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYamlStandalone.java

@@ -15,7 +15,7 @@
  */


-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
 
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
@@ -42,14 +42,9 @@ import static org.junit.Assert.assertTrue;
  * Please note that this class just tests YAML parsing,
  * but only in an isolated fashion.
  */
-public class TestRunJobCliParsingYamlStandalone {
+public class TestRunJobCliParsingTensorFlowYamlStandalone {
   private static final String OVERRIDDEN_PREFIX = "overridden_";
-  private static final String DIR_NAME = "runjobcliparsing";
-
-  @Before
-  public void before() {
-    SubmarineLogs.verboseOff();
-  }
+  private static final String DIR_NAME = "runjob-tensorflow-yaml";
 
   private void verifyBasicConfigValues(YamlConfigFile yamlConfigFile) {
     assertNotNull("Spec file should not be null!", yamlConfigFile);
@@ -169,6 +164,11 @@ public class TestRunJobCliParsingYamlStandalone {
     assertEquals("memory=21000M,vcores=37,gpu=3", tensorBoard.getResources());
   }
 
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
   @Test
   public void testLaunchCommandYaml() {
     YamlConfigFile yamlConfigFile = readYamlConfigFile(DIR_NAME +
@@ -201,5 +201,4 @@ public class TestRunJobCliParsingYamlStandalone {
     assertRoleConfigOverrides(roles.getWorker(), OVERRIDDEN_PREFIX, "Worker");
     assertRoleConfigOverrides(roles.getPs(), OVERRIDDEN_PREFIX, "Ps");
   }
-
 }

+ 63 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/empty-framework.yaml

@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework:
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+  ps:
+    resources: memory=20500M,vcores=34,gpu=4
+    replicas: 4
+    launch_cmd: testLaunchCmdPs
+    docker_image: testDockerImagePs
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true
+
+tensorBoard:
+  resources: memory=21000M,vcores=37,gpu=3
+  docker_image: tensorboardDockerImage

+ 63 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/invalid-framework.yaml

@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: bla
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+  ps:
+    resources: memory=20500M,vcores=34,gpu=4
+    replicas: 4
+    launch_cmd: testLaunchCmdPs
+    docker_image: testDockerImagePs
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true
+
+tensorBoard:
+  resources: memory=21000M,vcores=37,gpu=3
+  docker_image: tensorboardDockerImage

+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/missing-configs.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/missing-configs.yaml


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/valid-config.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/missing-framework.yaml


+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/some-sections-missing.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/some-sections-missing.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/test-false-values.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/test-false-values.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/wrong-indentation.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/wrong-indentation.yaml


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/wrong-property-name.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/wrong-property-name.yaml


+ 51 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/envs-are-missing.yaml

@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
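+  # Note: the 'envs' mapping is intentionally omitted in this fixture.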
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true

+ 56 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/invalid-config-ps-section.yaml

@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+  ps:
+    docker_image: testPSDockerImage
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true

+ 57 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/invalid-config-tensorboard-section.yaml

@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true
+
+tensorBoard:
+  docker_image: tensorboardDockerImage

+ 53 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/security-principal-is-missing.yaml

@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+
+security:
+  keytab: keytabPath
+  distribute_keytab: true

+ 63 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/valid-config-with-overrides.yaml

@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: overridden_testLaunchCmdWorker
+    docker_image: overridden_testDockerImageWorker
+    envs:
+      env1: 'overridden_env1Worker'
+      env2: 'overridden_env2Worker'
+    localizations:
+    - hdfs://remote-file1:/overridden_local-filename1Worker:rw
+    - nfs://remote-file2:/overridden_local-filename2Worker:rw
+    mounts:
+    - /etc/passwd:/overridden_Worker
+    - /etc/hosts:/overridden_Worker
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true

+ 54 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/valid-config.yaml

@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/envs-are-missing.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/envs-are-missing.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/security-principal-is-missing.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/security-principal-is-missing.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/tensorboard-dockerimage-is-missing.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/tensorboard-dockerimage-is-missing.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/valid-config-with-overrides.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/valid-config-with-overrides.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 63 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/valid-config.yaml

@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: tensorflow
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+  ps:
+    resources: memory=20500M,vcores=34,gpu=4
+    replicas: 4
+    launch_cmd: testLaunchCmdPs
+    docker_image: testDockerImagePs
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true
+
+tensorBoard:
+  resources: memory=21000M,vcores=37,gpu=3
+  docker_image: tensorboardDockerImage

+ 17 - 5
hadoop-submarine/hadoop-submarine-tony-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/tony/TonyJobSubmitter.java

@@ -22,7 +22,9 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
 
 import java.io.File;
@@ -45,14 +47,24 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
   }
 
   @Override
-  public ApplicationId submitJob(RunJobParameters parameters)
-      throws IOException, YarnException {
+  public ApplicationId submitJob(ParametersHolder parameters)
+      throws IOException {
+    if (parameters.getFramework() == Framework.PYTORCH) {
+      // We need to throw an exception, as ParametersHolder's parameters
+      // field cannot be cast to TensorFlowRunJobParameters.
+      throw new UnsupportedOperationException(
+          "Support for the \"--framework\" option for PyTorch with TonY " +
+              "is coming. Please check the documentation about how to " +
+              "submit a PyTorch job with the TonY runtime.");
+    }
+
     LOG.info("Starting Tony runtime..");
 
     File tonyFinalConfPath = File.createTempFile("temp",
         Constants.TONY_FINAL_XML);
     // Write user's overridden conf to an xml to be localized.
-    Configuration tonyConf = TonyUtils.tonyConfFromClientContext(parameters);
+    Configuration tonyConf = TonyUtils.tonyConfFromClientContext(
+        (TensorFlowRunJobParameters) parameters.getParameters());
     try (OutputStream os = new FileOutputStream(tonyFinalConfPath)) {
     try (OutputStream os = new FileOutputStream(tonyFinalConfPath)) {
       tonyConf.writeXml(os);
       tonyConf.writeXml(os);
     } catch (IOException e) {
     } catch (IOException e) {
@@ -68,7 +80,7 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
       LOG.error("Failed to init TonyClient: ", e);
     }
     Thread clientThread = new Thread(tonyClient::start);
-    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
+    java.lang.Runtime.getRuntime().addShutdownHook(new Thread(() -> {
       try {
         tonyClient.forceKillApplication();
       } catch (YarnException | IOException e) {

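Since the TonY runtime now rejects PyTorch jobs up front, a caller can handle the fast failure; a hypothetical fragment, not part of this patch (variable names illustrative):

    try {
      ApplicationId appId = jobSubmitter.submitJob(parametersHolder);
    } catch (UnsupportedOperationException e) {
      // PyTorch on TonY is not implemented yet; surface the message that
      // points to the documentation instead of a stack trace.
      LOG.error("Runtime rejected the job: " + e.getMessage());
    }
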
+ 2 - 2
hadoop-submarine/hadoop-submarine-tony-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/tony/TonyUtils.java

@@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ResourceInformation;
 import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -35,7 +35,7 @@ public final class TonyUtils {
   private static final Log LOG = LogFactory.getLog(TonyUtils.class);
 
   public static Configuration tonyConfFromClientContext(
-      RunJobParameters parameters) {
+      TensorFlowRunJobParameters parameters) {
     Configuration tonyConf = new Configuration();
     tonyConf.setInt(
         TonyConfigurationKeys.getInstancesKey(Constants.WORKER_JOB_NAME),

+ 4 - 0
hadoop-submarine/hadoop-submarine-tony-runtime/src/site/markdown/QuickStart.md

@@ -147,6 +147,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
 
 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+ --framework tensorflow \
  --num_workers 2 \
  --worker_resources memory=3G,vcores=2 \
  --num_ps 2 \
@@ -183,6 +184,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
 
 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+ --framework tensorflow \
  --docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
  --input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
  --worker_resources memory=3G,vcores=2 \
@@ -245,6 +247,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
 
 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+ --framework tensorflow \
  --num_workers 2 \
  --worker_resources memory=3G,vcores=2 \
  --num_ps 2 \
@@ -281,6 +284,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
 
 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+ --framework tensorflow \
  --docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
  --input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
  --worker_resources memory=3G,vcores=2 \

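All four QuickStart examples above submit TensorFlow jobs. For comparison, a hypothetical single-node PyTorch submission through the same CLI, mirroring the test-style invocation used elsewhere in this patch (MockClientContext is the test helper; the image tag and script path are illustrative only):

    RunJobCli runJobCli = new RunJobCli(new MockClientContext());
    runJobCli.run(new String[] {
        "--framework", "pytorch",
        "--name", "pytorch-job-001",
        "--docker_image", "pytorch-latest-gpu:0.0.1",   // illustrative tag
        "--num_workers", "1",                           // single-node only for now
        "--worker_resources", "memory=3G,vcores=2",
        "--worker_launch_cmd", "python cifar10_tutorial.py"});
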
+ 19 - 7
hadoop-submarine/hadoop-submarine-tony-runtime/src/test/java/TestTonyUtils.java

@@ -16,8 +16,10 @@ import com.linkedin.tony.TonyConfigurationKeys;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.MockClientContext;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
@@ -31,6 +33,7 @@ import org.junit.Test;
 
 import java.io.IOException;
 
+import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -59,7 +62,8 @@ public class TestTonyUtils {
       throws IOException, YarnException {
     MockClientContext mockClientContext = new MockClientContext();
     JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
-    when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
+    when(mockJobSubmitter.submitJob(
+        any(ParametersHolder.class))).thenReturn(
         ApplicationId.newInstance(1234L, 1));
     JobMonitor mockJobMonitor = mock(JobMonitor.class);
     SubmarineStorage storage = mock(SubmarineStorage.class);
@@ -82,20 +86,28 @@ public class TestTonyUtils {
   public void testTonyConfFromClientContext() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
     runJobCli.run(
-        new String[] {"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+        new String[] {"--framework", "tensorflow", "--name", "my-job",
+            "--docker_image", "tf-docker:1.1.0",
             "--input_path", "hdfs://input",
             "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
             "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
             "--ps_resources", "memory=4G,vcores=4", "--ps_launch_cmd",
             "python run-ps.py"});
     RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
     Configuration tonyConf = TonyUtils
-        .tonyConfFromClientContext(jobRunParameters);
+        .tonyConfFromClientContext(tensorFlowParams);
     Assert.assertEquals(jobRunParameters.getDockerImageName(),
         tonyConf.get(TonyConfigurationKeys.getContainerDockerKey()));
     Assert.assertEquals("3", tonyConf.get(TonyConfigurationKeys
         .getInstancesKey("worker")));
-    Assert.assertEquals(jobRunParameters.getWorkerLaunchCmd(),
+    Assert.assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
         tonyConf.get(TonyConfigurationKeys
             .getExecuteCommandKey("worker")));
     Assert.assertEquals("2048", tonyConf.get(TonyConfigurationKeys
@@ -107,7 +119,7 @@ public class TestTonyUtils {
     Assert.assertEquals("4", tonyConf.get(TonyConfigurationKeys
     Assert.assertEquals("4", tonyConf.get(TonyConfigurationKeys
         .getResourceKey(Constants.PS_JOB_NAME,
         .getResourceKey(Constants.PS_JOB_NAME,
         Constants.VCORES)));
         Constants.VCORES)));
-    Assert.assertEquals(jobRunParameters.getPSLaunchCmd(),
+    Assert.assertEquals(tensorFlowParams.getPSLaunchCmd(),
         tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey("ps")));
         tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey("ps")));
   }
   }
 }
 }

+ 49 - 7
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java

@@ -19,8 +19,10 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
@@ -28,7 +30,11 @@ import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchComma
 import java.io.IOException;
 import java.util.Objects;
 
+import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
 
 /**
  * Abstract base class for Component classes.
@@ -40,7 +46,7 @@ import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.T
 public abstract class AbstractComponent {
   private final FileSystemOperations fsOperations;
   protected final RunJobParameters parameters;
-  protected final TaskType taskType;
+  protected final Role role;
   private final RemoteDirectoryManager remoteDirectoryManager;
   protected final Configuration yarnConfig;
   private final LaunchCommandFactory launchCommandFactory;
@@ -52,19 +58,55 @@ public abstract class AbstractComponent {
 
   public AbstractComponent(FileSystemOperations fsOperations,
       RemoteDirectoryManager remoteDirectoryManager,
-      RunJobParameters parameters, TaskType taskType,
+      RunJobParameters parameters, Role role,
       Configuration yarnConfig,
       LaunchCommandFactory launchCommandFactory) {
     this.fsOperations = fsOperations;
     this.remoteDirectoryManager = remoteDirectoryManager;
     this.parameters = parameters;
-    this.taskType = taskType;
+    this.role = role;
     this.launchCommandFactory = launchCommandFactory;
     this.yarnConfig = yarnConfig;
   }
 
   protected abstract Component createComponent() throws IOException;
 
+  protected Component createComponentInternal() throws IOException {
+    Objects.requireNonNull(this.parameters.getWorkerResource(),
+        "Worker resource must not be null!");
+    if (parameters.getNumWorkers() < 1) {
+      throw new IllegalArgumentException(
+          "Number of workers should be at least 1!");
+    }
+
+    Component component = new Component();
+    component.setName(role.getComponentName());
+
+    if (role.equals(TensorFlowRole.PRIMARY_WORKER) ||
+        role.equals(PyTorchRole.PRIMARY_WORKER)) {
+      component.setNumberOfContainers(1L);
+      component.getConfiguration().setProperty(
+          CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
+    } else {
+      component.setNumberOfContainers(
+          (long) parameters.getNumWorkers() - 1);
+    }
+
+    if (parameters.getWorkerDockerImage() != null) {
+      component.setArtifact(
+          getDockerArtifact(parameters.getWorkerDockerImage()));
+    }
+
+    component.setResource(convertYarnResourceToServiceResource(
+        parameters.getWorkerResource()));
+    component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
+
+    addCommonEnvironments(component, role);
+    generateLaunchCommand(component);
+
+    return component;
+  }
+
   /**
    * Generates a command launch script on local disk,
    * returns path to the script.
@@ -72,7 +114,7 @@
   protected void generateLaunchCommand(Component component)
       throws IOException {
     AbstractLaunchCommand launchCommand =
-        launchCommandFactory.createLaunchCommand(taskType, component);
+        launchCommandFactory.createLaunchCommand(role, component);
     this.localScriptFile = launchCommand.generateLaunchScript();
 
     String remoteLaunchCommand = uploadLaunchCommand(component);
@@ -86,7 +128,7 @@ public abstract class AbstractComponent {
     Path stagingDir =
         remoteDirectoryManager.getJobStagingArea(parameters.getName(), true);
 
-    String destScriptFileName = getScriptFileName(taskType);
+    String destScriptFileName = getScriptFileName(role);
     fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
         localScriptFile, destScriptFileName, component);
 

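With createComponentInternal() carrying the role-agnostic wiring (container counts, docker artifact, resources, restart policy, environments and launch command), a concrete worker component can reduce to a thin delegation. An illustrative subclass, not taken from this patch:

    public class ExampleWorkerComponent extends AbstractComponent {
      public ExampleWorkerComponent(FileSystemOperations fsOperations,
          RemoteDirectoryManager remoteDirectoryManager,
          RunJobParameters parameters, Role role,
          Configuration yarnConfig,
          LaunchCommandFactory launchCommandFactory) {
        super(fsOperations, remoteDirectoryManager, parameters, role,
            yarnConfig, launchCommandFactory);
      }

      @Override
      protected Component createComponent() throws IOException {
        // One container for a primary worker (reported as service state),
        // numWorkers - 1 for the remaining workers; all other setup shared.
        return createComponentInternal();
      }
    }
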
+ 167 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractServiceSpec.java

@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
+import org.apache.hadoop.yarn.submarine.utils.Localizer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
+
+/**
+ * Abstract base class that supports creating service specs for Native Service.
+ */
+public abstract class AbstractServiceSpec implements ServiceSpec {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(AbstractServiceSpec.class);
+  protected final RunJobParameters parameters;
+  protected final FileSystemOperations fsOperations;
+  private final Localizer localizer;
+  protected final RemoteDirectoryManager remoteDirectoryManager;
+  protected final Configuration yarnConfig;
+  protected final LaunchCommandFactory launchCommandFactory;
+  private final WorkerComponentFactory workerFactory;
+
+  public AbstractServiceSpec(RunJobParameters parameters,
+      ClientContext clientContext, FileSystemOperations fsOperations,
+      LaunchCommandFactory launchCommandFactory,
+      Localizer localizer) {
+    this.parameters = parameters;
+    this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
+    this.yarnConfig = clientContext.getYarnConfig();
+    this.fsOperations = fsOperations;
+    this.localizer = localizer;
+    this.launchCommandFactory = launchCommandFactory;
+    this.workerFactory = new WorkerComponentFactory(fsOperations,
+        remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
+  }
+
+  protected ServiceWrapper createServiceSpecWrapper() throws IOException {
+    Service serviceSpec = new Service();
+    serviceSpec.setName(parameters.getName());
+    serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
+    serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
+
+    KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
+        .create(fsOperations, remoteDirectoryManager, parameters);
+    if (kerberosPrincipal != null) {
+      serviceSpec.setKerberosPrincipal(kerberosPrincipal);
+    }
+
+    handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
+    localizer.handleLocalizations(serviceSpec);
+    return new ServiceWrapper(serviceSpec);
+  }
+
+
+  // Handle worker and primary_worker.
+  protected void addWorkerComponents(ServiceWrapper serviceWrapper,
+      Framework framework)
+      throws IOException {
+    final Role primaryWorkerRole;
+    final Role workerRole;
+    if (framework == Framework.TENSORFLOW) {
+      primaryWorkerRole = TensorFlowRole.PRIMARY_WORKER;
+      workerRole = TensorFlowRole.WORKER;
+    } else {
+      primaryWorkerRole = PyTorchRole.PRIMARY_WORKER;
+      workerRole = PyTorchRole.WORKER;
+    }
+
+    addWorkerComponent(serviceWrapper, primaryWorkerRole, framework);
+
+    if (parameters.getNumWorkers() > 1) {
+      addWorkerComponent(serviceWrapper, workerRole, framework);
+    }
+  }
+  private void addWorkerComponent(ServiceWrapper serviceWrapper,
+      Role role, Framework framework) throws IOException {
+    AbstractComponent component = workerFactory.create(framework, role);
+    serviceWrapper.addComponent(component);
+  }
+
+  protected void handleQuicklinks(Service serviceSpec)
+      throws IOException {
+    List<Quicklink> quicklinks = parameters.getQuicklinks();
+    if (quicklinks != null && !quicklinks.isEmpty()) {
+      for (Quicklink ql : quicklinks) {
+        // Make sure it is a valid instance name
+        String instanceName = ql.getComponentInstanceName();
+        boolean found = false;
+
+        for (Component comp : serviceSpec.getComponents()) {
+          for (int i = 0; i < comp.getNumberOfContainers(); i++) {
+            String possibleInstanceName = comp.getName() + "-" + i;
+            if (possibleInstanceName.equals(instanceName)) {
+              found = true;
+              break;
+            }
+          }
+        }
+
+        if (!found) {
+          throw new IOException(
+              "Couldn't find a component instance = " + instanceName
+                  + " while adding quicklink");
+        }
+
+        String link = ql.getProtocol()
+            + YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
+            getUserName(), getDNSDomain(yarnConfig), ql.getPort());
+        addQuicklink(serviceSpec, ql.getLabel(), link);
+      }
+    }
+  }
+
+  protected static void addQuicklink(Service serviceSpec, String label,
+      String link) {
+    Map<String, String> quicklinks = serviceSpec.getQuicklinks();
+    if (quicklinks == null) {
+      quicklinks = new HashMap<>();
+      serviceSpec.setQuicklinks(quicklinks);
+    }
+
+    if (SubmarineLogs.isVerbose()) {
+      LOG.info("Added quicklink, " + label + "=" + link);
+    }
+
+    quicklinks.put(label, link);
+  }
+}

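handleQuicklinks() accepts a quicklink only if its component-instance part matches a generated instance name, which follows the &lt;component-name&gt;-&lt;index&gt; pattern. A small sketch of that convention (component name and count assumed for illustration; List/ArrayList imports implied):

    // A component "master" with two containers yields master-0 and master-1,
    // so a quicklink like Notebook_UI=https://master-0:7070 passes the check.
    String componentName = "master";
    long containers = 2;
    List<String> validInstanceNames = new ArrayList<>();
    for (int i = 0; i < containers; i++) {
      validInstanceNames.add(componentName + "-" + i);
    }
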
+ 10 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java

@@ -37,6 +37,7 @@ import java.io.File;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 
 /**
@@ -195,6 +196,15 @@ public class FileSystemOperations {
     fs.setPermission(destPath, new FsPermission(permission));
   }
 
+  public static boolean needHdfs(List<String> stringsToCheck) {
+    for (String content : stringsToCheck) {
+      if (content != null && content.contains("hdfs://")) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   public static boolean needHdfs(String content) {
     return content != null && content.contains("hdfs://");
   }

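A quick sketch of the new overload's behavior: one entry containing an hdfs:// URI is enough to return true, and null entries are tolerated because each element is null-checked.

    List<String> toCheck = Arrays.asList(
        "python train.py", null, "hdfs://namenode:9000/dataset");
    boolean needed = FileSystemOperations.needHdfs(toCheck);            // true
    boolean notNeeded = FileSystemOperations.needHdfs("file:///tmp/x"); // false
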
+ 20 - 5
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java

@@ -16,9 +16,10 @@
 
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
 
+import org.apache.curator.shaded.com.google.common.collect.ImmutableList;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
@@ -28,6 +29,8 @@ import org.slf4j.LoggerFactory;
 import java.io.File;
 import java.io.IOException;
 import java.io.PrintWriter;
+import java.util.List;
+import java.util.Objects;
 
 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
 import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath;
@@ -128,10 +131,22 @@ public class HadoopEnvironmentSetup {
   }
 
   private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) {
-    return needHdfs(parameters.getInputPath()) ||
-        needHdfs(parameters.getPSLaunchCmd()) ||
-        needHdfs(parameters.getWorkerLaunchCmd()) ||
-        hadoopEnv;
+    List<String> launchCommands = parameters.getLaunchCommands();
+    if (launchCommands != null) {
+      launchCommands.removeIf(Objects::isNull);
+    }
+
+    ImmutableList.Builder<String> listBuilder = ImmutableList.builder();
+
+    if (launchCommands != null && !launchCommands.isEmpty()) {
+      listBuilder.addAll(launchCommands);
+    }
+    if (parameters.getInputPath() != null) {
+      listBuilder.add(parameters.getInputPath());
+    }
+    List<String> stringsToCheck = listBuilder.build();
+
+    return needHdfs(stringsToCheck) || hadoopEnv;
   }
 
   private void appendHdfsHome(PrintWriter fw, String hdfsHome) {

+ 1 - 1
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java

@@ -38,7 +38,7 @@ public final class ServiceSpecFileGenerator {
         "instantiated!");
         "instantiated!");
   }
   }
 
 
-  static String generateJson(Service service) throws IOException {
+  public static String generateJson(Service service) throws IOException {
     File serviceSpecFile = File.createTempFile(service.getName(), ".json");
     File serviceSpecFile = File.createTempFile(service.getName(), ".json");
     String buffer = jsonSerDeser.toJson(service);
     String buffer = jsonSerDeser.toJson(service);
     Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),
     Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),

+ 71 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/WorkerComponentFactory.java

@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component.PyTorchWorkerComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
+
+/**
+ * Factory class that helps creating Native Service components.
+ */
+public class WorkerComponentFactory {
+  private final FileSystemOperations fsOperations;
+  private final RemoteDirectoryManager remoteDirectoryManager;
+  private final RunJobParameters parameters;
+  private final LaunchCommandFactory launchCommandFactory;
+  private final Configuration yarnConfig;
+
+  WorkerComponentFactory(FileSystemOperations fsOperations,
+      RemoteDirectoryManager remoteDirectoryManager,
+      RunJobParameters parameters,
+      LaunchCommandFactory launchCommandFactory,
+      Configuration yarnConfig) {
+    this.fsOperations = fsOperations;
+    this.remoteDirectoryManager = remoteDirectoryManager;
+    this.parameters = parameters;
+    this.launchCommandFactory = launchCommandFactory;
+    this.yarnConfig = yarnConfig;
+  }
+
+  /**
+   * Creates either a TensorFlow or a PyTorch Native Service component.
+   */
+  public AbstractComponent create(Framework framework, Role role) {
+    if (framework == Framework.TENSORFLOW) {
+      return new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
+          (TensorFlowRunJobParameters) parameters, role,
+          (TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
+    } else if (framework == Framework.PYTORCH) {
+      return new PyTorchWorkerComponent(fsOperations, remoteDirectoryManager,
+          (PyTorchRunJobParameters) parameters, role,
+          (PyTorchLaunchCommandFactory) launchCommandFactory, yarnConfig);
+    } else {
+      throw new UnsupportedOperationException("Only supported frameworks are: "
+          + Framework.getValues());
+    }
+  }
+}

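A usage sketch for the factory (wiring assumed as in AbstractServiceSpec): the framework argument selects the concrete component type, and the casts inside create() mean the parameters and launch-command factory must already match that framework.

    AbstractComponent tfWorker =
        workerFactory.create(Framework.TENSORFLOW, TensorFlowRole.WORKER);
    AbstractComponent ptWorker =
        workerFactory.create(Framework.PYTORCH, PyTorchRole.PRIMARY_WORKER);
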
+ 62 - 7
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java

@@ -20,10 +20,16 @@ import org.apache.hadoop.yarn.client.api.AppAdminClient;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.service.api.records.Service;
 import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.PyTorchServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
 import org.apache.hadoop.yarn.submarine.utils.Localizer;
 import org.slf4j.Logger;
@@ -32,6 +38,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 
 import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
+import static org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder.SUPPORTED_FRAMEWORKS_MESSAGE;
 
 /**
  * Submit a job to cluster.
@@ -51,14 +58,45 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
    * {@inheritDoc}
    */
   @Override
-  public ApplicationId submitJob(RunJobParameters parameters)
+  public ApplicationId submitJob(ParametersHolder paramsHolder)
       throws IOException, YarnException {
+    Framework framework = paramsHolder.getFramework();
+    RunJobParameters parameters =
+        (RunJobParameters) paramsHolder.getParameters();
+
+    if (framework == Framework.TENSORFLOW) {
+      return submitTensorFlowJob((TensorFlowRunJobParameters) parameters);
+    } else if (framework == Framework.PYTORCH) {
+      return submitPyTorchJob((PyTorchRunJobParameters) parameters);
+    } else {
+      throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
+    }
+  }
+
+  private ApplicationId submitTensorFlowJob(
+      TensorFlowRunJobParameters parameters) throws IOException, YarnException {
     FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
     HadoopEnvironmentSetup hadoopEnvSetup =
         new HadoopEnvironmentSetup(clientContext, fsOperations);
 
     Service serviceSpec = createTensorFlowServiceSpec(parameters,
         fsOperations, hadoopEnvSetup);
+    return submitJobInternal(serviceSpec);
+  }
+
+  private ApplicationId submitPyTorchJob(PyTorchRunJobParameters parameters)
+      throws IOException, YarnException {
+    FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
+    HadoopEnvironmentSetup hadoopEnvSetup =
+        new HadoopEnvironmentSetup(clientContext, fsOperations);
+
+    Service serviceSpec = createPyTorchServiceSpec(parameters,
+        fsOperations, hadoopEnvSetup);
+    return submitJobInternal(serviceSpec);
+  }
+
+  private ApplicationId submitJobInternal(Service serviceSpec)
+      throws IOException, YarnException {
     String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);
 
     AppAdminClient appAdminClient =
@@ -70,7 +108,7 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
           "Fail to launch application with exit code:" + code);
           "Fail to launch application with exit code:" + code);
     }
     }
 
 
-    String appStatus=appAdminClient.getStatusString(serviceSpec.getName());
+    String appStatus = appAdminClient.getStatusString(serviceSpec.getName());
     Service app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus);
 
     // Retry multiple times if applicationId is null
@@ -97,11 +135,12 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
     return appid;
   }
 
-  private Service createTensorFlowServiceSpec(RunJobParameters parameters,
+  private Service createTensorFlowServiceSpec(
+      TensorFlowRunJobParameters parameters,
       FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
       throws IOException {
-    LaunchCommandFactory launchCommandFactory =
-        new LaunchCommandFactory(hadoopEnvSetup, parameters,
+    TensorFlowLaunchCommandFactory launchCommandFactory =
+        new TensorFlowLaunchCommandFactory(hadoopEnvSetup, parameters,
             clientContext.getYarnConfig());
     Localizer localizer = new Localizer(fsOperations,
         clientContext.getRemoteDirectoryManager(), parameters);
@@ -113,6 +152,22 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
     return serviceWrapper.getService();
   }
 
+  private Service createPyTorchServiceSpec(PyTorchRunJobParameters parameters,
+      FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
+      throws IOException {
+    PyTorchLaunchCommandFactory launchCommandFactory =
+        new PyTorchLaunchCommandFactory(hadoopEnvSetup, parameters,
+            clientContext.getYarnConfig());
+    Localizer localizer = new Localizer(fsOperations,
+        clientContext.getRemoteDirectoryManager(), parameters);
+    PyTorchServiceSpec pyTorchServiceSpec = new PyTorchServiceSpec(
+        parameters, this.clientContext, fsOperations, launchCommandFactory,
+        localizer);
+
+    serviceWrapper = pyTorchServiceSpec.create();
+    return serviceWrapper.getService();
+  }
+
   @VisibleForTesting
   public ServiceWrapper getServiceWrapper() {
     return serviceWrapper;

+ 4 - 7
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommand.java

@@ -17,11 +17,9 @@
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
 
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import java.io.IOException;
-import java.util.Objects;
 
 /**
  * Abstract base class for Launch command implementations for Services.
@@ -32,10 +30,9 @@ public abstract class AbstractLaunchCommand {
   private final LaunchScriptBuilder builder;
 
   public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters)
-      throws IOException {
-    Objects.requireNonNull(taskType, "TaskType must not be null!");
-    this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
+      Component component, RunJobParameters parameters,
+      String launchCommandPrefix) throws IOException {
+    this.builder = new LaunchScriptBuilder(launchCommandPrefix, hadoopEnvSetup,
         parameters, component);
   }
 

+ 5 - 42
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchCommandFactory.java

@@ -16,52 +16,15 @@
 
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
 
-import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 
 import java.io.IOException;
-import java.util.Objects;
 
 /**
- * Simple factory to create instances of {@link AbstractLaunchCommand}
- * based on the {@link TaskType}.
- * All dependencies are passed to this factory that could be required
- * by any implementor of {@link AbstractLaunchCommand}.
+ * Interface for creating launch commands.
  */
-public class LaunchCommandFactory {
-  private final HadoopEnvironmentSetup hadoopEnvSetup;
-  private final RunJobParameters parameters;
-  private final Configuration yarnConfig;
-
-  public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
-      RunJobParameters parameters, Configuration yarnConfig) {
-    this.hadoopEnvSetup = hadoopEnvSetup;
-    this.parameters = parameters;
-    this.yarnConfig = yarnConfig;
-  }
-
-  public AbstractLaunchCommand createLaunchCommand(TaskType taskType,
-      Component component) throws IOException {
-    Objects.requireNonNull(taskType, "TaskType must not be null!");
-
-    if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) {
-      return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
-          component, parameters, yarnConfig);
-
-    } else if (taskType == TaskType.PS) {
-      return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
-          parameters, yarnConfig);
-
-    } else if (taskType == TaskType.TENSORBOARD) {
-      return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component,
-          parameters);
-    }
-    throw new IllegalStateException("Unknown task type: " + taskType);
-  }
+public interface LaunchCommandFactory {
+  AbstractLaunchCommand createLaunchCommand(Role role, Component component)
+      throws IOException;
 }

+ 4 - 3
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java

@@ -17,7 +17,7 @@
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
 
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -47,10 +47,11 @@ public class LaunchScriptBuilder {
   private final StringBuilder scriptBuffer;
   private String launchCommand;
 
-  LaunchScriptBuilder(String namePrefix,
+  LaunchScriptBuilder(String launchScriptPrefix,
       HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters,
       Component component) throws IOException {
-    this.file = File.createTempFile(namePrefix + "-launch-script", ".sh");
+    this.file = File.createTempFile(launchScriptPrefix +
+        "-launch-script", ".sh");
     this.hadoopEnvSetup = hadoopEnvSetup;
     this.parameters = parameters;
     this.component = component;

+ 61 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/PyTorchLaunchCommandFactory.java

@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command.PyTorchWorkerLaunchCommand;
+
+/**
+ * Simple factory to create instances of {@link AbstractLaunchCommand}
+ * based on the {@link Role}.
+ * All dependencies are passed to this factory that could be required
+ * by any implementor of {@link AbstractLaunchCommand}.
+ */
+public class PyTorchLaunchCommandFactory implements LaunchCommandFactory {
+  private final HadoopEnvironmentSetup hadoopEnvSetup;
+  private final PyTorchRunJobParameters parameters;
+  private final Configuration yarnConfig;
+
+  public PyTorchLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
+      PyTorchRunJobParameters parameters, Configuration yarnConfig) {
+    this.hadoopEnvSetup = hadoopEnvSetup;
+    this.parameters = parameters;
+    this.yarnConfig = yarnConfig;
+  }
+
+  public AbstractLaunchCommand createLaunchCommand(Role role,
+      Component component) throws IOException {
+    Objects.requireNonNull(role, "Role must not be null!");
+
+    if (role == PyTorchRole.WORKER ||
+        role == PyTorchRole.PRIMARY_WORKER) {
+      return new PyTorchWorkerLaunchCommand(hadoopEnvSetup, role,
+          component, parameters, yarnConfig);
+
+    } else {
+      throw new IllegalStateException("Unknown task type: " + role);
+    }
+  }
+}

+ 70 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TensorFlowLaunchCommandFactory.java

@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Simple factory to create instances of {@link AbstractLaunchCommand}
+ * based on the {@link Role}.
+ * All dependencies are passed to this factory that could be required
+ * by any implementor of {@link AbstractLaunchCommand}.
+ */
+public class TensorFlowLaunchCommandFactory implements LaunchCommandFactory {
+  private final HadoopEnvironmentSetup hadoopEnvSetup;
+  private final TensorFlowRunJobParameters parameters;
+  private final Configuration yarnConfig;
+
+  public TensorFlowLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
+      TensorFlowRunJobParameters parameters, Configuration yarnConfig) {
+    this.hadoopEnvSetup = hadoopEnvSetup;
+    this.parameters = parameters;
+    this.yarnConfig = yarnConfig;
+  }
+
+  @Override
+  public AbstractLaunchCommand createLaunchCommand(Role role,
+      Component component) throws IOException {
+    Objects.requireNonNull(role, "Role must not be null!");
+
+    if (role == TensorFlowRole.WORKER ||
+        role == TensorFlowRole.PRIMARY_WORKER) {
+      return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, role,
+          component, parameters, yarnConfig);
+
+    } else if (role == TensorFlowRole.PS) {
+      return new TensorFlowPsLaunchCommand(hadoopEnvSetup, role, component,
+          parameters, yarnConfig);
+
+    } else if (role == TensorFlowRole.TENSORBOARD) {
+      return new TensorBoardLaunchCommand(hadoopEnvSetup, role, component,
+          parameters);
+    }
+    throw new IllegalStateException("Unknown task type: " + role);
+  }
+}

+ 68 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/PyTorchServiceSpec.java

@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;
+
+import java.io.IOException;
+
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.utils.Localizer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class contains all the logic to create an instance
+ * of a {@link Service} object for PyTorch.
+ * Please note that currently, only single-node (non-distributed)
+ * support is implemented for PyTorch.
+ */
+public class PyTorchServiceSpec extends AbstractServiceSpec {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(PyTorchServiceSpec.class);
+  // This field is not used yet, but it will be needed in the future!
+  private final PyTorchRunJobParameters pyTorchParameters;
+
+  public PyTorchServiceSpec(PyTorchRunJobParameters parameters,
+      ClientContext clientContext, FileSystemOperations fsOperations,
+      PyTorchLaunchCommandFactory launchCommandFactory, Localizer localizer) {
+    super(parameters, clientContext, fsOperations, launchCommandFactory,
+        localizer);
+    this.pyTorchParameters = parameters;
+  }
+
+  @Override
+  public ServiceWrapper create() throws IOException {
+    LOG.info("Creating PyTorch service spec");
+    ServiceWrapper serviceWrapper = createServiceSpecWrapper();
+
+    if (parameters.getNumWorkers() > 0) {
+      addWorkerComponents(serviceWrapper, Framework.PYTORCH);
+    }
+
+    // After all components added, handle quicklinks
+    handleQuicklinks(serviceWrapper.getService());
+
+    return serviceWrapper;
+  }
+
+}
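
As a usage sketch (hypothetical glue code, not from this commit; all dependencies are assumed to be created by the surrounding runtime), the spec is built and then queried for the underlying YARN Service record:

import java.io.IOException;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.PyTorchServiceSpec;
import org.apache.hadoop.yarn.submarine.utils.Localizer;

// Hypothetical glue code: build the spec, then pull out the Service record
// that will be submitted to YARN.
public final class PyTorchSpecUsage {
  static Service buildService(PyTorchRunJobParameters params,
      ClientContext ctx, FileSystemOperations fsOps,
      PyTorchLaunchCommandFactory factory, Localizer localizer)
      throws IOException {
    PyTorchServiceSpec spec =
        new PyTorchServiceSpec(params, ctx, fsOps, factory, localizer);
    ServiceWrapper wrapper = spec.create();
    return wrapper.getService();
  }
}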

+ 87 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/command/PyTorchWorkerLaunchCommand.java

@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;
+
+import java.io.IOException;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Launch command implementation for the PyTorch worker component.
+ */
+public class PyTorchWorkerLaunchCommand extends AbstractLaunchCommand {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(PyTorchWorkerLaunchCommand.class);
+  private final Configuration yarnConfig;
+  private final boolean distributed;
+  private final int numberOfWorkers;
+  private final String name;
+  private final Role role;
+  private final String launchCommand;
+
+  public PyTorchWorkerLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
+      Role role, Component component,
+      PyTorchRunJobParameters parameters,
+      Configuration yarnConfig) throws IOException {
+    super(hadoopEnvSetup, component, parameters, role.getName());
+    this.role = role;
+    this.name = parameters.getName();
+    this.distributed = parameters.isDistributed();
+    this.numberOfWorkers = parameters.getNumWorkers();
+    this.yarnConfig = yarnConfig;
+    logReceivedParameters();
+
+    this.launchCommand = parameters.getWorkerLaunchCmd();
+
+    if (StringUtils.isEmpty(this.launchCommand)) {
+      throw new IllegalArgumentException("LaunchCommand must not be null " +
+          "or empty!");
+    }
+  }
+
+  private void logReceivedParameters() {
+    if (this.numberOfWorkers <= 0) {
+      LOG.warn("Received number of workers: {}", this.numberOfWorkers);
+    }
+  }
+
+  @Override
+  public String generateLaunchScript() throws IOException {
+    LaunchScriptBuilder builder = getBuilder();
+    return builder
+        .withLaunchCommand(createLaunchCommand())
+        .build();
+  }
+
+  @Override
+  public String createLaunchCommand() {
+    if (SubmarineLogs.isVerbose()) {
+      LOG.info("PyTorch Worker command =[" + launchCommand + "]");
+    }
+    return launchCommand + '\n';
+  }
+}
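
The contract here is small: the constructor rejects a missing worker launch command, and the script body is simply the user's command plus a trailing newline, appended after the Hadoop environment setup contributed by LaunchScriptBuilder. A self-contained sketch of that behavior with illustrative values only (toy class, not the real implementation):

// Illustrative only: mirrors the validation and rendering logic above
// without the YARN service dependencies.
public final class WorkerCommandSketch {
  static String render(String workerLaunchCmd) {
    if (workerLaunchCmd == null || workerLaunchCmd.isEmpty()) {
      // Same failure mode as the constructor above.
      throw new IllegalArgumentException(
          "LaunchCommand must not be null or empty!");
    }
    return workerLaunchCmd + '\n';
  }

  public static void main(String[] args) {
    // e.g. a single-node CIFAR-10 run
    System.out.print(render("python cifar10_tutorial.py"));
  }
}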

+ 19 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/command/package-info.java

@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate PyTorch launch commands.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;

+ 47 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/component/PyTorchWorkerComponent.java

@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+
+import java.io.IOException;
+
+/**
+ * Component implementation for the worker process of PyTorch.
+ */
+public class PyTorchWorkerComponent extends AbstractComponent {
+  public PyTorchWorkerComponent(FileSystemOperations fsOperations,
+      RemoteDirectoryManager remoteDirectoryManager,
+      PyTorchRunJobParameters parameters, Role role,
+      PyTorchLaunchCommandFactory launchCommandFactory,
+      Configuration yarnConfig) {
+    super(fsOperations, remoteDirectoryManager, parameters, role,
+        yarnConfig, launchCommandFactory);
+  }
+
+  @Override
+  public Component createComponent() throws IOException {
+    return createComponentInternal();
+  }
+}

+ 20 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/component/package-info.java

@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate
+ * PyTorch Native Service components.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;

+ 20 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/package-info.java

@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate
+ * PyTorch-related Native Service runtime artifacts.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;

+ 5 - 5
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowCommons.java

@@ -20,7 +20,7 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
 import org.apache.hadoop.yarn.service.api.records.Component;
 import org.apache.hadoop.yarn.submarine.common.Envs;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
 
 import java.util.Map;
@@ -35,10 +35,10 @@ public final class TensorFlowCommons {
   }
 
   public static void addCommonEnvironments(Component component,
-      TaskType taskType) {
+      Role role) {
     Map<String, String> envs = component.getConfiguration().getEnv();
     envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
-    envs.put(Envs.TASK_TYPE_ENV, taskType.name());
+    envs.put(Envs.TASK_TYPE_ENV, role.getName());
   }
 
   public static String getUserName() {
@@ -49,8 +49,8 @@
     return yarnConfig.get("hadoop.registry.dns.domain-name");
   }
 
-  public static String getScriptFileName(TaskType taskType) {
-    return "run-" + taskType.name() + ".sh";
+  public static String getScriptFileName(Role role) {
+    return "run-" + role.getName() + ".sh";
   }
 
   public static String getTFConfigEnv(String componentName, int nWorkers,
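
One observable effect of the TaskType-to-Role migration above is that script file names are now derived from Role.getName() rather than the enum constant's name(). A minimal sketch, assuming getName() returns the same value the old TaskType.WORKER.name() did (the diff does not show Role's implementation; hypothetical caller):

import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons;

public final class ScriptNameSketch {
  public static void main(String[] args) {
    // Prints "run-WORKER.sh" under the assumption above.
    System.out.println(
        TensorFlowCommons.getScriptFileName(TensorFlowRole.WORKER));
  }
}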

+ 22 - 125
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowServiceSpec.java

@@ -16,39 +16,24 @@
 
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
 
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
 import org.apache.hadoop.yarn.service.api.records.Service;
-import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
-import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
-import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
 import org.apache.hadoop.yarn.submarine.utils.Localizer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
 
-import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
-import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL;
-import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
-import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
 
 /**
  * This class contains all the logic to create an instance
@@ -56,42 +41,34 @@ import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handle
  * Worker,PS and Tensorboard components are added to the Service
  * based on the value of the received {@link RunJobParameters}.
 */
-public class TensorFlowServiceSpec implements ServiceSpec {
+public class TensorFlowServiceSpec extends AbstractServiceSpec {
   private static final Logger LOG =
       LoggerFactory.getLogger(TensorFlowServiceSpec.class);
+  private final TensorFlowRunJobParameters tensorFlowParameters;
 
-  private final RemoteDirectoryManager remoteDirectoryManager;
-
-  private final RunJobParameters parameters;
-  private final Configuration yarnConfig;
-  private final FileSystemOperations fsOperations;
-  private final LaunchCommandFactory launchCommandFactory;
-  private final Localizer localizer;
-
-  public TensorFlowServiceSpec(RunJobParameters parameters,
+  public TensorFlowServiceSpec(TensorFlowRunJobParameters parameters,
       ClientContext clientContext, FileSystemOperations fsOperations,
-      LaunchCommandFactory launchCommandFactory, Localizer localizer) {
-    this.parameters = parameters;
-    this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
-    this.yarnConfig = clientContext.getYarnConfig();
-    this.fsOperations = fsOperations;
-    this.launchCommandFactory = launchCommandFactory;
-    this.localizer = localizer;
+      TensorFlowLaunchCommandFactory launchCommandFactory,
+      Localizer localizer) {
+    super(parameters, clientContext, fsOperations, launchCommandFactory,
+        localizer);
+    this.tensorFlowParameters = parameters;
   }
 
   @Override
   public ServiceWrapper create() throws IOException {
+    LOG.info("Creating TensorFlow service spec");
     ServiceWrapper serviceWrapper = createServiceSpecWrapper();
 
-    if (parameters.getNumWorkers() > 0) {
-      addWorkerComponents(serviceWrapper);
+    if (tensorFlowParameters.getNumWorkers() > 0) {
+      addWorkerComponents(serviceWrapper, Framework.TENSORFLOW);
     }
 
-    if (parameters.getNumPS() > 0) {
+    if (tensorFlowParameters.getNumPS() > 0) {
       addPsComponent(serviceWrapper);
     }
 
-    if (parameters.isTensorboardEnabled()) {
+    if (tensorFlowParameters.isTensorboardEnabled()) {
       createTensorBoardComponent(serviceWrapper);
     }
 
@@ -101,103 +78,23 @@ public class TensorFlowServiceSpec implements ServiceSpec {
     return serviceWrapper;
   }
 
-  private ServiceWrapper createServiceSpecWrapper() throws IOException {
-    Service serviceSpec = new Service();
-    serviceSpec.setName(parameters.getName());
-    serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
-    serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
-
-    KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
-        .create(fsOperations, remoteDirectoryManager, parameters);
-    if (kerberosPrincipal != null) {
-      serviceSpec.setKerberosPrincipal(kerberosPrincipal);
-    }
-
-    handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
-    localizer.handleLocalizations(serviceSpec);
-    return new ServiceWrapper(serviceSpec);
-  }
-
   private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
       throws IOException {
     TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
-        remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
+        remoteDirectoryManager, parameters,
+        (TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
     serviceWrapper.addComponent(tbComponent);
 
     addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
         tbComponent.getTensorboardLink());
   }
 
-  private static void addQuicklink(Service serviceSpec, String label,
-      String link) {
-    Map<String, String> quicklinks = serviceSpec.getQuicklinks();
-    if (quicklinks == null) {
-      quicklinks = new HashMap<>();
-      serviceSpec.setQuicklinks(quicklinks);
-    }
-
-    if (SubmarineLogs.isVerbose()) {
-      LOG.info("Added quicklink, " + label + "=" + link);
-    }
-
-    quicklinks.put(label, link);
-  }
-
-  private void handleQuicklinks(Service serviceSpec)
-      throws IOException {
-    List<Quicklink> quicklinks = parameters.getQuicklinks();
-    if (quicklinks != null && !quicklinks.isEmpty()) {
-      for (Quicklink ql : quicklinks) {
-        // Make sure it is a valid instance name
-        String instanceName = ql.getComponentInstanceName();
-        boolean found = false;
-
-        for (Component comp : serviceSpec.getComponents()) {
-          for (int i = 0; i < comp.getNumberOfContainers(); i++) {
-            String possibleInstanceName = comp.getName() + "-" + i;
-            if (possibleInstanceName.equals(instanceName)) {
-              found = true;
-              break;
-            }
-          }
-        }
-
-        if (!found) {
-          throw new IOException(
-              "Couldn't find a component instance = " + instanceName
-                  + " while adding quicklink");
-        }
-
-        String link = ql.getProtocol()
-            + YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
-                getUserName(), getDNSDomain(yarnConfig), ql.getPort());
-        addQuicklink(serviceSpec, ql.getLabel(), link);
-      }
-    }
-  }
-
-  // Handle worker and primary_worker.
-
-  private void addWorkerComponents(ServiceWrapper serviceWrapper)
-      throws IOException {
-    addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);
-
-    if (parameters.getNumWorkers() > 1) {
-      addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
-    }
-  }
-  private void addWorkerComponent(ServiceWrapper serviceWrapper,
-      RunJobParameters parameters, TaskType taskType) throws IOException {
-    serviceWrapper.addComponent(
-        new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
-        parameters, taskType, launchCommandFactory, yarnConfig));
-  }
-
   private void addPsComponent(ServiceWrapper serviceWrapper)
      throws IOException {
     serviceWrapper.addComponent(
         new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
-            launchCommandFactory, parameters, yarnConfig));
+            (TensorFlowLaunchCommandFactory) launchCommandFactory,
+            parameters, yarnConfig));
   }
 
 }
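
The net effect of this refactoring: spec wrapper creation, quicklink handling, and worker-component wiring moved into AbstractServiceSpec so that TensorFlowServiceSpec and PyTorchServiceSpec can share them. A self-contained illustration of that shape follows (toy types, not the real classes):

import java.util.ArrayList;
import java.util.List;

// Toy illustration: the base class owns the shared steps, and each
// framework spec composes them in its own create().
abstract class SpecBaseSketch {
  protected List<String> createWrapper() {
    return new ArrayList<>();
  }
  protected void addWorkers(List<String> wrapper, String framework) {
    wrapper.add(framework + "-worker");
  }
  protected void handleQuicklinks(List<String> wrapper) {
    wrapper.add("quicklinks");
  }
  abstract List<String> create();
}

class TensorFlowSpecSketch extends SpecBaseSketch {
  @Override
  List<String> create() {
    List<String> wrapper = createWrapper();
    addWorkers(wrapper, "tensorflow"); // the real class also adds PS/TensorBoard
    handleQuicklinks(wrapper);
    return wrapper;
  }

  public static void main(String[] args) {
    System.out.println(new TensorFlowSpecSketch().create());
  }
}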

+ 4 - 4
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorBoardLaunchCommand.java

@@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.slf4j.Logger;
@@ -37,9 +37,9 @@ public class TensorBoardLaunchCommand extends AbstractLaunchCommand {
   private final String checkpointPath;
 
   public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters)
+      Role role, Component component, RunJobParameters parameters)
       throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters);
+    super(hadoopEnvSetup, component, parameters, role.getName());
     Objects.requireNonNull(parameters.getCheckpointPath(),
         "CheckpointPath must not be null as it is part "
             + "of the tensorboard command!");

+ 11 - 7
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java

@@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
@@ -28,6 +28,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.Objects;
 
 /**
  * Launch command implementation for
@@ -41,13 +42,16 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
   private final int numberOfWorkers;
   private final int numberOfPS;
   private final String name;
-  private final TaskType taskType;
+  private final Role role;
 
   TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters,
+      Role role, Component component,
+      TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters);
-    this.taskType = taskType;
+    super(hadoopEnvSetup, component, parameters,
+        role != null ? role.getName(): "");
+    Objects.requireNonNull(role, "TensorFlowRole must not be null!");
+    this.role = role;
     this.name = parameters.getName();
     this.distributed = parameters.isDistributed();
     this.numberOfWorkers = parameters.getNumWorkers();
@@ -72,7 +76,7 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
     // When distributed training is required
     if (distributed) {
       String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
-          taskType.getComponentName(), numberOfWorkers,
+          role.getComponentName(), numberOfWorkers,
           numberOfPS, name,
           TensorFlowCommons.getUserName(),
           TensorFlowCommons.getDNSDomain(yarnConfig));
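
This diff does not show getTFConfigEnv's output, but TF_CONFIG is TensorFlow's standard cluster-spec environment variable: a JSON document naming each role's hosts plus the current task. An illustrative value for a 2-worker, 1-PS job; the hostnames follow the YARN registry DNS pattern only by assumption, and the port is invented:

public final class TfConfigSketch {
  public static void main(String[] args) {
    // Hedged illustration of a typical TF_CONFIG value, not Submarine's
    // exact output.
    String tfConfigExample =
        "{\"cluster\":{"
        + "\"worker\":[\"myjob-worker-0.user.example.com:8000\","
        + "\"myjob-worker-1.user.example.com:8000\"],"
        + "\"ps\":[\"myjob-ps-0.user.example.com:8000\"]},"
        + "\"task\":{\"type\":\"worker\",\"index\":0}}";
    System.out.println(tfConfigExample);
  }
}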

+ 5 - 4
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java

@@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
@@ -37,9 +37,9 @@ public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand {
   private final String launchCommand;
 
   public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters,
+      Role role, Component component,
+      TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
+    super(hadoopEnvSetup, role, component, parameters, yarnConfig);
     this.launchCommand = parameters.getPSLaunchCmd();
 
     if (StringUtils.isEmpty(this.launchCommand)) {

+ 5 - 5
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java

@@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
@@ -37,10 +37,10 @@ public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand {
   private final String launchCommand;
 
   public TensorFlowWorkerLaunchCommand(
-      HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType,
-      Component component, RunJobParameters parameters,
+      HadoopEnvironmentSetup hadoopEnvSetup, Role role,
+      Component component, TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
+    super(hadoopEnvSetup, role, component, parameters, yarnConfig);
     this.launchCommand = parameters.getWorkerLaunchCmd();
 
     if (StringUtils.isEmpty(this.launchCommand)) {

+ 16 - 12
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java

@@ -19,13 +19,14 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.compone
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
 import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -54,35 +55,38 @@ public class TensorBoardComponent extends AbstractComponent {
   public TensorBoardComponent(FileSystemOperations fsOperations,
       RemoteDirectoryManager remoteDirectoryManager,
       RunJobParameters parameters,
-      LaunchCommandFactory launchCommandFactory,
+      TensorFlowLaunchCommandFactory launchCommandFactory,
       Configuration yarnConfig) {
     super(fsOperations, remoteDirectoryManager, parameters,
-        TaskType.TENSORBOARD, yarnConfig, launchCommandFactory);
+        TensorFlowRole.TENSORBOARD, yarnConfig, launchCommandFactory);
   }
 
   @Override
   public Component createComponent() throws IOException {
-    Objects.requireNonNull(parameters.getTensorboardResource(),
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) this.parameters;
+
+    Objects.requireNonNull(tensorFlowParams.getTensorboardResource(),
         "TensorBoard resource must not be null!");
 
     Component component = new Component();
-    component.setName(taskType.getComponentName());
+    component.setName(role.getComponentName());
     component.setNumberOfContainers(1L);
     component.setRestartPolicy(RestartPolicyEnum.NEVER);
     component.setResource(convertYarnResourceToServiceResource(
-        parameters.getTensorboardResource()));
+        tensorFlowParams.getTensorboardResource()));
 
-    if (parameters.getTensorboardDockerImage() != null) {
+    if (tensorFlowParams.getTensorboardDockerImage() != null) {
       component.setArtifact(
-          getDockerArtifact(parameters.getTensorboardDockerImage()));
+          getDockerArtifact(tensorFlowParams.getTensorboardDockerImage()));
     }
 
-    addCommonEnvironments(component, taskType);
+    addCommonEnvironments(component, role);
     generateLaunchCommand(component);
 
     tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
         parameters.getName(),
-        taskType.getComponentName() + "-" + 0, getUserName(),
+        role.getComponentName() + "-" + 0, getUserName(),
         getDNSDomain(yarnConfig), DEFAULT_PORT);
     LOG.info("Link to tensorboard:" + tensorboardLink);
 
 
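
For orientation, the quicklink assembled above follows YARN Service's registry DNS naming (component instance, service name, user, registry domain). With a job named my-job run by user alice under domain example.com, the link would look roughly like the value below; DEFAULT_PORT is not shown in this hunk, so 6006 (TensorBoard's conventional port) is assumed:

public final class TensorboardLinkSketch {
  public static void main(String[] args) {
    // Illustrative only; the real value comes from
    // YarnServiceUtils.getDNSName with the job's actual settings.
    String link = "http://tensorboard-0.my-job.alice.example.com:6006";
    System.out.println(link);
  }
}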

Some files were not shown because too many files changed in this diff