Browse Source

SUBMARINE-52. [SUBMARINE-14] Generate Service spec + launch script for single-node PyTorch learning job. Contributed by Szilard Nemeth.

Zhankun Tang 6 years ago
parent
commit
36267b6f7c
100 changed files with 4286 additions and 973 deletions
  1. 77 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/base/ubuntu-16.04/Dockerfile.gpu.pytorch_latest
  2. 30 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/build-all.sh
  3. 354 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/with-cifar10-models/cifar10_tutorial.py
  4. 21 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.pytorch_latest
  5. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/base/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1
  6. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/base/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1
  7. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/build-all.sh
  8. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1
  9. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1
  10. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/README.md
  11. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10.py
  12. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_main.py
  13. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_model.py
  14. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_utils.py
  15. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/generate_cifar10_tfrecords.py
  16. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/model_base.py
  17. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/Dockerfile.gpu
  18. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/run_container.sh
  19. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/shiro.ini
  20. 0 0
      hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/zeppelin-site.xml
  21. 1 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/Cli.java
  22. 2 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/CliConstants.java
  23. 1 1
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/CliUtils.java
  24. 24 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/Command.java
  25. 6 6
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/ShowJobCli.java
  26. 24 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/ConfigType.java
  27. 134 9
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/ParametersHolder.java
  28. 120 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/PyTorchRunJobParameters.java
  29. 146 197
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/RunJobParameters.java
  30. 215 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/TensorFlowRunJobParameters.java
  31. 20 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/package-info.java
  32. 9 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/yaml/Spec.java
  33. 59 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/Framework.java
  34. 81 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/RoleParameters.java
  35. 123 106
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/RunJobCli.java
  36. 19 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/package-info.java
  37. 54 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/PyTorchRole.java
  38. 25 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/Role.java
  39. 58 0
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/Runtime.java
  40. 11 2
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/TensorFlowRole.java
  41. 5 5
      hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/common/JobSubmitter.java
  42. 7 0
      hadoop-submarine/hadoop-submarine-core/src/site/markdown/QuickStart.md
  43. 0 226
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java
  44. 6 5
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/YamlConfigTestUtils.java
  45. 129 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingCommon.java
  46. 252 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingCommonYaml.java
  47. 192 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingParameterized.java
  48. 209 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/pytorch/TestRunJobCliParsingPyTorch.java
  49. 225 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/pytorch/TestRunJobCliParsingPyTorchYaml.java
  50. 170 0
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlow.java
  51. 45 159
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYaml.java
  52. 8 9
      hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYamlStandalone.java
  53. 63 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/empty-framework.yaml
  54. 63 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/invalid-framework.yaml
  55. 0 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/missing-configs.yaml
  56. 0 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/missing-framework.yaml
  57. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/some-sections-missing.yaml
  58. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/test-false-values.yaml
  59. 0 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/wrong-indentation.yaml
  60. 0 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/wrong-property-name.yaml
  61. 51 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/envs-are-missing.yaml
  62. 56 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/invalid-config-ps-section.yaml
  63. 57 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/invalid-config-tensorboard-section.yaml
  64. 53 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/security-principal-is-missing.yaml
  65. 63 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/valid-config-with-overrides.yaml
  66. 54 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/valid-config.yaml
  67. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/envs-are-missing.yaml
  68. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/security-principal-is-missing.yaml
  69. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/tensorboard-dockerimage-is-missing.yaml
  70. 1 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/valid-config-with-overrides.yaml
  71. 63 0
      hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/valid-config.yaml
  72. 17 5
      hadoop-submarine/hadoop-submarine-tony-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/tony/TonyJobSubmitter.java
  73. 2 2
      hadoop-submarine/hadoop-submarine-tony-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/tony/TonyUtils.java
  74. 4 0
      hadoop-submarine/hadoop-submarine-tony-runtime/src/site/markdown/QuickStart.md
  75. 19 7
      hadoop-submarine/hadoop-submarine-tony-runtime/src/test/java/TestTonyUtils.java
  76. 49 7
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java
  77. 167 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractServiceSpec.java
  78. 10 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java
  79. 20 5
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java
  80. 1 1
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java
  81. 71 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/WorkerComponentFactory.java
  82. 62 7
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java
  83. 4 7
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommand.java
  84. 5 42
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchCommandFactory.java
  85. 4 3
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java
  86. 61 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/PyTorchLaunchCommandFactory.java
  87. 70 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TensorFlowLaunchCommandFactory.java
  88. 68 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/PyTorchServiceSpec.java
  89. 87 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/command/PyTorchWorkerLaunchCommand.java
  90. 19 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/command/package-info.java
  91. 47 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/component/PyTorchWorkerComponent.java
  92. 20 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/component/package-info.java
  93. 20 0
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/package-info.java
  94. 5 5
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowCommons.java
  95. 22 125
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowServiceSpec.java
  96. 4 4
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorBoardLaunchCommand.java
  97. 11 7
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
  98. 5 4
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java
  99. 5 5
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java
  100. 16 12
      hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java

+ 77 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/base/ubuntu-16.04/Dockerfile.gpu.pytorch_latest

@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
+ARG PYTHON_VERSION=3.6
+RUN apt-get update && apt-get install -y --no-install-recommends \
+         build-essential \
+         cmake \
+         git \
+         curl \
+         vim \
+         ca-certificates \
+         libjpeg-dev \
+         libpng-dev \
+         wget &&\
+     rm -rf /var/lib/apt/lists/*
+
+
+RUN curl -o ~/miniconda.sh -O  https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh  && \
+     chmod +x ~/miniconda.sh && \
+     ~/miniconda.sh -b -p /opt/conda && \
+     rm ~/miniconda.sh && \
+     /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \
+     /opt/conda/bin/conda install -y -c pytorch magma-cuda100 && \
+     /opt/conda/bin/conda clean -ya
+ENV PATH /opt/conda/bin:$PATH
+RUN pip install ninja
+# This must be done before pip so that requirements.txt is available
+WORKDIR /opt/pytorch
+RUN git clone https://github.com/pytorch/pytorch.git
+WORKDIR pytorch
+RUN git submodule update --init
+RUN TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
+    CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
+    pip install -v .
+
+WORKDIR /opt/pytorch
+RUN git clone https://github.com/pytorch/vision.git && cd vision && pip install -v .
+
+WORKDIR /
+# Install Hadoop
+ENV HADOOP_VERSION="3.1.2"
+RUN wget https://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/hadoop-${HADOOP_VERSION}.tar.gz
+RUN tar zxf hadoop-${HADOOP_VERSION}.tar.gz
+RUN ln -s hadoop-${HADOOP_VERSION} hadoop-current
+RUN rm hadoop-${HADOOP_VERSION}.tar.gz
+
+ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
+RUN echo "$LOG_TAG Install java8" && \
+    apt-get update && \
+    apt-get install -y --no-install-recommends openjdk-8-jdk && \
+    apt-get clean && rm -rf /var/lib/apt/lists/*
+
+RUN echo "Install python related packages" && \
+    pip --no-cache-dir install Pillow h5py ipykernel jupyter matplotlib numpy pandas scipy sklearn && \
+    python -m ipykernel.kernelspec
+
+# Set the locale to fix bash warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)
+RUN apt-get update && apt-get install -y --no-install-recommends locales && \
+    apt-get clean && rm -rf /var/lib/apt/lists/*
+RUN locale-gen en_US.UTF-8
+
+
+WORKDIR /workspace
+RUN chmod -R a+w /workspace

+ 30 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/build-all.sh

@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+echo "Building base images"
+
+set -e
+
+cd base/ubuntu-16.04
+
+docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu-base:0.0.1
+
+echo "Finished building base images"
+
+cd ../../with-cifar10-models/ubuntu-16.04
+
+docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu:0.0.1

+ 354 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/with-cifar10-models/cifar10_tutorial.py

@@ -0,0 +1,354 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#    http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# -*- coding: utf-8 -*-
+"""
+Training a Classifier
+=====================
+
+This is it. You have seen how to define neural networks, compute loss and make
+updates to the weights of the network.
+
+Now you might be thinking,
+
+What about data?
+----------------
+
+Generally, when you have to deal with image, text, audio or video data,
+you can use standard python packages that load data into a numpy array.
+Then you can convert this array into a ``torch.*Tensor``.
+
+-  For images, packages such as Pillow, OpenCV are useful
+-  For audio, packages such as scipy and librosa
+-  For text, either raw Python or Cython based loading, or NLTK and
+   SpaCy are useful
+
+Specifically for vision, we have created a package called
+``torchvision``, that has data loaders for common datasets such as
+Imagenet, CIFAR10, MNIST, etc. and data transformers for images, viz.,
+``torchvision.datasets`` and ``torch.utils.data.DataLoader``.
+
+This provides a huge convenience and avoids writing boilerplate code.
+
+For this tutorial, we will use the CIFAR10 dataset.
+It has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,
+‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of
+size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.
+
+.. figure:: /_static/img/cifar10.png
+   :alt: cifar10
+
+   cifar10
+
+
+Training an image classifier
+----------------------------
+
+We will do the following steps in order:
+
+1. Load and normalizing the CIFAR10 training and test datasets using
+   ``torchvision``
+2. Define a Convolutional Neural Network
+3. Define a loss function
+4. Train the network on the training data
+5. Test the network on the test data
+
+1. Loading and normalizing CIFAR10
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Using ``torchvision``, it’s extremely easy to load CIFAR10.
+"""
+import torch
+import torchvision
+import torchvision.transforms as transforms
+
+########################################################################
+# The output of torchvision datasets are PILImage images of range [0, 1].
+# We transform them to Tensors of normalized range [-1, 1].
+
+transform = transforms.Compose(
+  [transforms.ToTensor(),
+   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
+                                        download=True, transform=transform)
+trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
+                                          shuffle=True, num_workers=2)
+
+testset = torchvision.datasets.CIFAR10(root='./data', train=False,
+                                       download=True, transform=transform)
+testloader = torch.utils.data.DataLoader(testset, batch_size=4,
+                                         shuffle=False, num_workers=2)
+
+classes = ('plane', 'car', 'bird', 'cat',
+           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+
+########################################################################
+# Let us show some of the training images, for fun.
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+# functions to show an image
+
+
+def imshow(img):
+  img = img / 2 + 0.5  # unnormalize
+  npimg = img.numpy()
+  plt.imshow(np.transpose(npimg, (1, 2, 0)))
+  plt.show()
+
+
+# get some random training images
+dataiter = iter(trainloader)
+images, labels = dataiter.next()
+
+# show images
+imshow(torchvision.utils.make_grid(images))
+# print labels
+print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
+
+########################################################################
+# 2. Define a Convolutional Neural Network
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Copy the neural network from the Neural Networks section before and modify it to
+# take 3-channel images (instead of 1-channel images as it was defined).
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Net(nn.Module):
+  def __init__(self):
+    super(Net, self).__init__()
+    self.conv1 = nn.Conv2d(3, 6, 5)
+    self.pool = nn.MaxPool2d(2, 2)
+    self.conv2 = nn.Conv2d(6, 16, 5)
+    self.fc1 = nn.Linear(16 * 5 * 5, 120)
+    self.fc2 = nn.Linear(120, 84)
+    self.fc3 = nn.Linear(84, 10)
+
+  def forward(self, x):
+    x = self.pool(F.relu(self.conv1(x)))
+    x = self.pool(F.relu(self.conv2(x)))
+    x = x.view(-1, 16 * 5 * 5)
+    x = F.relu(self.fc1(x))
+    x = F.relu(self.fc2(x))
+    x = self.fc3(x)
+    return x
+
+
+net = Net()
+
+########################################################################
+# 3. Define a Loss function and optimizer
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Let's use a Classification Cross-Entropy loss and SGD with momentum.
+
+import torch.optim as optim
+
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+
+########################################################################
+# 4. Train the network
+# ^^^^^^^^^^^^^^^^^^^^
+#
+# This is when things start to get interesting.
+# We simply have to loop over our data iterator, and feed the inputs to the
+# network and optimize.
+
+for epoch in range(2):  # loop over the dataset multiple times
+
+  running_loss = 0.0
+  for i, data in enumerate(trainloader, 0):
+    # get the inputs
+    inputs, labels = data
+
+    # zero the parameter gradients
+    optimizer.zero_grad()
+
+    # forward + backward + optimize
+    outputs = net(inputs)
+    loss = criterion(outputs, labels)
+    loss.backward()
+    optimizer.step()
+
+    # print statistics
+    running_loss += loss.item()
+    if i % 2000 == 1999:  # print every 2000 mini-batches
+      print('[%d, %5d] loss: %.3f' %
+            (epoch + 1, i + 1, running_loss / 2000))
+      running_loss = 0.0
+
+print('Finished Training')
+
+########################################################################
+# 5. Test the network on the test data
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# We have trained the network for 2 passes over the training dataset.
+# But we need to check if the network has learnt anything at all.
+#
+# We will check this by predicting the class label that the neural network
+# outputs, and checking it against the ground-truth. If the prediction is
+# correct, we add the sample to the list of correct predictions.
+#
+# Okay, first step. Let us display an image from the test set to get familiar.
+
+dataiter = iter(testloader)
+images, labels = dataiter.next()
+
+# print images
+imshow(torchvision.utils.make_grid(images))
+print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
+
+########################################################################
+# Okay, now let us see what the neural network thinks these examples above are:
+
+outputs = net(images)
+
+########################################################################
+# The outputs are energies for the 10 classes.
+# The higher the energy for a class, the more the network
+# thinks that the image is of the particular class.
+# So, let's get the index of the highest energy:
+_, predicted = torch.max(outputs, 1)
+
+print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
+                              for j in range(4)))
+
+########################################################################
+# The results seem pretty good.
+#
+# Let us look at how the network performs on the whole dataset.
+
+correct = 0
+total = 0
+with torch.no_grad():
+  for data in testloader:
+    images, labels = data
+    outputs = net(images)
+    _, predicted = torch.max(outputs.data, 1)
+    total += labels.size(0)
+    correct += (predicted == labels).sum().item()
+
+print('Accuracy of the network on the 10000 test images: %d %%' % (
+        100 * correct / total))
+
+########################################################################
+# That looks waaay better than chance, which is 10% accuracy (randomly picking
+# a class out of 10 classes).
+# Seems like the network learnt something.
+#
+# Hmmm, what are the classes that performed well, and the classes that did
+# not perform well:
+
+class_correct = list(0. for i in range(10))
+class_total = list(0. for i in range(10))
+with torch.no_grad():
+  for data in testloader:
+    images, labels = data
+    outputs = net(images)
+    _, predicted = torch.max(outputs, 1)
+    c = (predicted == labels).squeeze()
+    for i in range(4):
+      label = labels[i]
+      class_correct[label] += c[i].item()
+      class_total[label] += 1
+
+for i in range(10):
+  print('Accuracy of %5s : %2d %%' % (
+    classes[i], 100 * class_correct[i] / class_total[i]))
+
+########################################################################
+# Okay, so what next?
+#
+# How do we run these neural networks on the GPU?
+#
+# Training on GPU
+# ----------------
+# Just like how you transfer a Tensor onto the GPU, you transfer the neural
+# net onto the GPU.
+#
+# Let's first define our device as the first visible cuda device if we have
+# CUDA available:
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Assuming that we are on a CUDA machine, this should print a CUDA device:
+
+print(device)
+
+########################################################################
+# The rest of this section assumes that ``device`` is a CUDA device.
+#
+# Then these methods will recursively go over all modules and convert their
+# parameters and buffers to CUDA tensors:
+#
+# .. code:: python
+#
+#     net.to(device)
+#
+#
+# Remember that you will have to send the inputs and targets at every step
+# to the GPU too:
+#
+# .. code:: python
+#
+#         inputs, labels = inputs.to(device), labels.to(device)
+#
+# Why dont I notice MASSIVE speedup compared to CPU? Because your network
+# is realllly small.
+#
+# **Exercise:** Try increasing the width of your network (argument 2 of
+# the first ``nn.Conv2d``, and argument 1 of the second ``nn.Conv2d`` –
+# they need to be the same number), see what kind of speedup you get.
+#
+# **Goals achieved**:
+#
+# - Understanding PyTorch's Tensor library and neural networks at a high level.
+# - Train a small neural network to classify images
+#
+# Training on multiple GPUs
+# -------------------------
+# If you want to see even more MASSIVE speedup using all of your GPUs,
+# please check out :doc:`data_parallel_tutorial`.
+#
+# Where do I go next?
+# -------------------
+#
+# -  :doc:`Train neural nets to play video games </intermediate/reinforcement_q_learning>`
+# -  `Train a state-of-the-art ResNet network on imagenet`_
+# -  `Train a face generator using Generative Adversarial Networks`_
+# -  `Train a word-level language model using Recurrent LSTM networks`_
+# -  `More examples`_
+# -  `More tutorials`_
+# -  `Discuss PyTorch on the Forums`_
+# -  `Chat with other users on Slack`_
+#
+# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/master/imagenet
+# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/master/dcgan
+# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/master/word_language_model
+# .. _More examples: https://github.com/pytorch/examples
+# .. _More tutorials: https://github.com/pytorch/tutorials
+# .. _Discuss PyTorch on the Forums: https://discuss.pytorch.org/
+# .. _Chat with other users on Slack: https://pytorch.slack.com/messages/beginner/
+
+# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%
+del dataiter
+# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%

+ 21 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/pytorch/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.pytorch_latest

@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM pytorch-latest-gpu-base:0.0.1
+
+RUN mkdir -p /test/data
+RUN chmod -R 777 /test
+ADD cifar10_tutorial.py /test/cifar10_tutorial.py

+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/base/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1 → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/base/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/base/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1 → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/base/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/build-all.sh → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/build-all.sh


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1 → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/Dockerfile.cpu.tf_1.13.1


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1 → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.tf_1.13.1


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/README.md → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/README.md


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_main.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_main.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_model.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_model.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_utils.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/cifar10_utils.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/generate_cifar10_tfrecords.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/generate_cifar10_tfrecords.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/model_base.py → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/with-cifar10-models/ubuntu-16.04/cifar10_estimator_tf_1.13.1/model_base.py


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/zeppelin-notebook-example/Dockerfile.gpu → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/Dockerfile.gpu


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/zeppelin-notebook-example/run_container.sh → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/run_container.sh


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/zeppelin-notebook-example/shiro.ini → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/shiro.ini


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/main/docker/zeppelin-notebook-example/zeppelin-site.xml → hadoop-submarine/hadoop-submarine-core/src/main/docker/tensorflow/zeppelin-notebook-example/zeppelin-site.xml


+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/Cli.java

@@ -15,6 +15,7 @@
 package org.apache.hadoop.yarn.submarine.client.cli;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;

+ 2 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/CliConstants.java

@@ -59,4 +59,6 @@ public class CliConstants {
   public static final String DISTRIBUTE_KEYTAB = "distribute_keytab";
   public static final String YAML_CONFIG = "f";
   public static final String INSECURE_CLUSTER = "insecure";
+
+  public static final String FRAMEWORK = "framework";
 }

+ 1 - 1
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/CliUtils.java

@@ -16,7 +16,7 @@ package org.apache.hadoop.yarn.submarine.client.cli;
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.exception.SubmarineRuntimeException;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.slf4j.Logger;

+ 24 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/Command.java

@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli;
+
+/**
+ * Represents a Submarine command.
+ */
+public enum Command {
+  RUN_JOB, SHOW_JOB
+}

+ 6 - 6
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/ShowJobCli.java

@@ -37,7 +37,7 @@ public class ShowJobCli extends AbstractCli {
   private static final Logger LOG = LoggerFactory.getLogger(ShowJobCli.class);
 
   private Options options;
-  private ShowJobParameters parameters = new ShowJobParameters();
+  private ParametersHolder parametersHolder;
 
   public ShowJobCli(ClientContext cliContext) {
     super(cliContext);
@@ -62,9 +62,9 @@ public class ShowJobCli extends AbstractCli {
     CommandLine cli;
     try {
       cli = parser.parse(options, args);
-      ParametersHolder parametersHolder = ParametersHolder
-          .createWithCmdLine(cli);
-      parameters.updateParameters(parametersHolder, clientContext);
+      parametersHolder = ParametersHolder
+          .createWithCmdLine(cli, Command.SHOW_JOB);
+      parametersHolder.updateParameters(clientContext);
     } catch (ParseException e) {
       printUsages();
     }
@@ -97,7 +97,7 @@ public class ShowJobCli extends AbstractCli {
 
     Map<String, String> jobInfo = null;
     try {
-      jobInfo = storage.getJobInfoByName(parameters.getName());
+      jobInfo = storage.getJobInfoByName(getParameters().getName());
     } catch (IOException e) {
       LOG.error("Failed to retrieve job info", e);
       throw e;
@@ -108,7 +108,7 @@ public class ShowJobCli extends AbstractCli {
 
   @VisibleForTesting
   public ShowJobParameters getParameters() {
-    return parameters;
+    return (ShowJobParameters) parametersHolder.getParameters();
   }
 
   @Override

+ 24 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/ConfigType.java

@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.param;
+
+/**
+ * Represents the source of configuration.
+ */
+public enum ConfigType {
+  YAML, CLI
+}

+ 134 - 9
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/ParametersHolder.java

@@ -20,8 +20,12 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.ParseException;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.Command;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Roles;
@@ -29,15 +33,22 @@ import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Scheduling;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Security;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.TensorBoard;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
+import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli.YAML_PARSE_FAILED;
+
 /**
  * This class acts as a wrapper of {@code CommandLine} values along with
  * YAML configuration values.
@@ -52,17 +63,110 @@ public final class ParametersHolder {
   private static final Logger LOG =
       LoggerFactory.getLogger(ParametersHolder.class);
 
+  public static final String SUPPORTED_FRAMEWORKS_MESSAGE =
+      "TensorFlow and PyTorch are the only supported frameworks for now!";
+  public static final String SUPPORTED_COMMANDS_MESSAGE =
+      "'Show job' and 'run job' are the only supported commands for now!";
+
+
+
   private final CommandLine parsedCommandLine;
   private final Map<String, String> yamlStringConfigs;
   private final Map<String, List<String>> yamlListConfigs;
-  private final ImmutableSet onlyDefinedWithCliArgs = ImmutableSet.of(
+  private final ConfigType configType;
+  private Command command;
+  private final Set onlyDefinedWithCliArgs = ImmutableSet.of(
       CliConstants.VERBOSE);
+  private final Framework framework;
+  private final BaseParameters parameters;
 
   private ParametersHolder(CommandLine parsedCommandLine,
-      YamlConfigFile yamlConfig) {
+      YamlConfigFile yamlConfig, ConfigType configType, Command command)
+      throws ParseException, YarnException {
     this.parsedCommandLine = parsedCommandLine;
     this.yamlStringConfigs = initStringConfigValues(yamlConfig);
     this.yamlListConfigs = initListConfigValues(yamlConfig);
+    this.configType = configType;
+    this.command = command;
+    this.framework = determineFrameworkType();
+    this.ensureOnlyValidSectionsAreDefined(yamlConfig);
+    this.parameters = createParameters();
+  }
+
+  private BaseParameters createParameters() {
+    if (command == Command.RUN_JOB) {
+      if (framework == Framework.TENSORFLOW) {
+        return new TensorFlowRunJobParameters();
+      } else if (framework == Framework.PYTORCH) {
+        return new PyTorchRunJobParameters();
+      } else {
+        throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
+      }
+    } else if (command == Command.SHOW_JOB) {
+      return new ShowJobParameters();
+    } else {
+      throw new UnsupportedOperationException(SUPPORTED_COMMANDS_MESSAGE);
+    }
+  }
+
+  private void ensureOnlyValidSectionsAreDefined(YamlConfigFile yamlConfig) {
+    if (isCommandRunJob() && isFrameworkPyTorch() &&
+        isPsSectionDefined(yamlConfig)) {
+      throw new YamlParseException(
+          "PS section should not be defined when PyTorch " +
+              "is the selected framework!");
+    }
+
+    if (isCommandRunJob() && isFrameworkPyTorch() &&
+        isTensorboardSectionDefined(yamlConfig)) {
+      throw new YamlParseException(
+          "TensorBoard section should not be defined when PyTorch " +
+              "is the selected framework!");
+    }
+  }
+
+  private boolean isCommandRunJob() {
+    return command == Command.RUN_JOB;
+  }
+
+  private boolean isFrameworkPyTorch() {
+    return framework == Framework.PYTORCH;
+  }
+
+  private boolean isPsSectionDefined(YamlConfigFile yamlConfig) {
+    return yamlConfig != null &&
+        yamlConfig.getRoles() != null &&
+        yamlConfig.getRoles().getPs() != null;
+  }
+
+  private boolean isTensorboardSectionDefined(YamlConfigFile yamlConfig) {
+    return yamlConfig != null &&
+        yamlConfig.getTensorBoard() != null;
+  }
+
+  private Framework determineFrameworkType()
+      throws ParseException, YarnException {
+    if (!isCommandRunJob()) {
+      return null;
+    }
+    String frameworkStr = getOptionValue(CliConstants.FRAMEWORK);
+    if (frameworkStr == null) {
+      LOG.info("Framework is not defined in config, falling back to " +
+          "TensorFlow as a default.");
+      return Framework.TENSORFLOW;
+    }
+    Framework framework = Framework.parseByValue(frameworkStr);
+    if (framework == null) {
+      if (getConfigType() == ConfigType.CLI) {
+        throw new ParseException("Failed to parse Framework type! "
+            + "Valid values are: " + Framework.getValues());
+      } else {
+        throw new YamlParseException(YAML_PARSE_FAILED +
+            ", framework should is defined, but it has an invalid value! " +
+            "Valid values are: " + Framework.getValues());
+      }
+    }
+    return framework;
   }
 
   /**
@@ -108,6 +212,8 @@ public final class ParametersHolder {
   private void initGenericConfigs(YamlConfigFile yamlConfig,
       Map<String, String> yamlConfigs) {
     yamlConfigs.put(CliConstants.NAME, yamlConfig.getSpec().getName());
+    yamlConfigs.put(CliConstants.FRAMEWORK,
+        yamlConfig.getSpec().getFramework());
 
     Configs configs = yamlConfig.getConfigs();
     yamlConfigs.put(CliConstants.INPUT_PATH, configs.getInputPath());
@@ -178,13 +284,15 @@ public final class ParametersHolder {
         .collect(Collectors.toList());
   }
 
-  public static ParametersHolder createWithCmdLine(CommandLine cli) {
-    return new ParametersHolder(cli, null);
+  public static ParametersHolder createWithCmdLine(CommandLine cli,
+      Command command) throws ParseException, YarnException {
+    return new ParametersHolder(cli, null, ConfigType.CLI, command);
   }
 
   public static ParametersHolder createWithCmdLineAndYaml(CommandLine cli,
-      YamlConfigFile yamlConfig) {
-    return new ParametersHolder(cli, yamlConfig);
+      YamlConfigFile yamlConfig, Command command) throws ParseException,
+      YarnException {
+    return new ParametersHolder(cli, yamlConfig, ConfigType.YAML, command);
   }
 
   /**
@@ -193,7 +301,7 @@ public final class ParametersHolder {
    * @param option Name of the config.
    * @return The value of the config
    */
-  String getOptionValue(String option) throws YarnException {
+  public String getOptionValue(String option) throws YarnException {
     ensureConfigIsDefinedOnce(option, true);
     if (onlyDefinedWithCliArgs.contains(option) ||
         parsedCommandLine.hasOption(option)) {
@@ -208,7 +316,7 @@ public final class ParametersHolder {
    * @param option Name of the config.
    * @return The values of the config
    */
-  List<String> getOptionValues(String option) throws YarnException {
+  public List<String> getOptionValues(String option) throws YarnException {
     ensureConfigIsDefinedOnce(option, false);
     if (onlyDefinedWithCliArgs.contains(option) ||
         parsedCommandLine.hasOption(option)) {
@@ -285,7 +393,7 @@ public final class ParametersHolder {
    * @return true, if the option is found in the CLI args or in the YAML config,
    * false otherwise.
    */
-  boolean hasOption(String option) {
+  public boolean hasOption(String option) {
     if (onlyDefinedWithCliArgs.contains(option)) {
       boolean value = parsedCommandLine.hasOption(option);
       if (LOG.isDebugEnabled()) {
@@ -312,4 +420,21 @@ public final class ParametersHolder {
         "from YAML configuration.", result, option);
     return result;
   }
+
+  public ConfigType getConfigType() {
+    return configType;
+  }
+
+  public Framework getFramework() {
+    return framework;
+  }
+
+  public void updateParameters(ClientContext clientContext)
+      throws ParseException, YarnException, IOException {
+    parameters.updateParameters(this, clientContext);
+  }
+
+  public BaseParameters getParameters() {
+    return parameters;
+  }
 }

+ 120 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/PyTorchRunJobParameters.java

@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Parameters for PyTorch job.
+ */
+public class PyTorchRunJobParameters extends RunJobParameters {
+
+  private static final String CANNOT_BE_DEFINED_FOR_PYTORCH =
+      "cannot be defined for PyTorch jobs!";
+
+  @Override
+  public void updateParameters(ParametersHolder parametersHolder,
+      ClientContext clientContext)
+      throws ParseException, IOException, YarnException {
+    checkArguments(parametersHolder);
+
+    super.updateParameters(parametersHolder, clientContext);
+
+    String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
+    this.workerParameters =
+        getWorkerParameters(clientContext, parametersHolder, input);
+    this.distributed = determineIfDistributed(workerParameters.getReplicas());
+    executePostOperations(clientContext);
+  }
+
+  private void checkArguments(ParametersHolder parametersHolder)
+      throws YarnException, ParseException {
+    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.N_PS));
+    } else if (parametersHolder.getOptionValue(CliConstants.PS_RES) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.PS_RES));
+    } else if (parametersHolder
+        .getOptionValue(CliConstants.PS_DOCKER_IMAGE) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.PS_DOCKER_IMAGE));
+    } else if (parametersHolder
+        .getOptionValue(CliConstants.PS_LAUNCH_CMD) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.PS_LAUNCH_CMD));
+    } else if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.TENSORBOARD));
+    } else if (parametersHolder
+        .getOptionValue(CliConstants.TENSORBOARD_RESOURCES) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.TENSORBOARD_RESOURCES));
+    } else if (parametersHolder
+        .getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE) != null) {
+      throw new ParseException(getParamCannotBeDefinedErrorMessage(
+          CliConstants.TENSORBOARD_DOCKER_IMAGE));
+    }
+  }
+
+  private String getParamCannotBeDefinedErrorMessage(String cliName) {
+    return String.format(
+        "Parameter '%s' " + CANNOT_BE_DEFINED_FOR_PYTORCH, cliName);
+  }
+
+  @Override
+  void executePostOperations(ClientContext clientContext) throws IOException {
+    // Set default job dir / saved model dir, etc.
+    setDefaultDirs(clientContext);
+    replacePatternsInParameters(clientContext);
+  }
+
+  private void replacePatternsInParameters(ClientContext clientContext)
+      throws IOException {
+    if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
+      String afterReplace =
+          CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
+              clientContext.getRemoteDirectoryManager());
+      setWorkerLaunchCmd(afterReplace);
+    }
+  }
+
+  @Override
+  public List<String> getLaunchCommands() {
+    return Lists.newArrayList(getWorkerLaunchCmd());
+  }
+
+  /**
+   * We only support non-distributed PyTorch integration for now.
+   * @param nWorkers
+   * @return
+   */
+  private boolean determineIfDistributed(int nWorkers) {
+    return false;
+  }
+}

+ 146 - 197
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java → hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/RunJobParameters.java

@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 /**
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -12,7 +28,7 @@
  * limitations under the License. See accompanying LICENSE file.
  */
 
-package org.apache.hadoop.yarn.submarine.client.cli.param;
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.CaseFormat;
@@ -21,7 +37,14 @@ import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
 import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.util.resource.ResourceUtils;
 import org.yaml.snakeyaml.introspector.Property;
 import org.yaml.snakeyaml.introspector.PropertyUtils;
@@ -34,27 +57,15 @@ import java.util.List;
 /**
  * Parameters used to run a job
  */
-public class RunJobParameters extends RunParameters {
+public abstract class RunJobParameters extends RunParameters {
   private String input;
   private String checkpointPath;
 
-  private int numWorkers;
-  private int numPS;
-  private Resource workerResource;
-  private Resource psResource;
-  private boolean tensorboardEnabled;
-  private Resource tensorboardResource;
-  private String tensorboardDockerImage;
-  private String workerLaunchCmd;
-  private String psLaunchCmd;
   private List<Quicklink> quicklinks = new ArrayList<>();
   private List<Localization> localizations = new ArrayList<>();
 
-  private String psDockerImage = null;
-  private String workerDockerImage = null;
-
   private boolean waitJobFinish = false;
-  private boolean distributed = false;
+  protected boolean distributed = false;
 
   private boolean securityDisabled = false;
   private String keytab;
@@ -62,6 +73,9 @@ public class RunJobParameters extends RunParameters {
   private boolean distributeKeytab = false;
   private List<String> confPairs = new ArrayList<>();
 
+  RoleParameters workerParameters =
+      RoleParameters.createEmpty(TensorFlowRole.WORKER);
+
   @Override
   public void updateParameters(ParametersHolder parametersHolder,
       ClientContext clientContext)
@@ -70,34 +84,6 @@ public class RunJobParameters extends RunParameters {
     String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
     String jobDir = parametersHolder.getOptionValue(
         CliConstants.CHECKPOINT_PATH);
-    int nWorkers = 1;
-    if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
-      nWorkers = Integer.parseInt(
-          parametersHolder.getOptionValue(CliConstants.N_WORKERS));
-      // Only check null value.
-      // Training job shouldn't ignore INPUT_PATH option
-      // But if nWorkers is 0, INPUT_PATH can be ignored because
-      // user can only run Tensorboard
-      if (null == input && 0 != nWorkers) {
-        throw new ParseException("\"--" + CliConstants.INPUT_PATH +
-            "\" is absent");
-      }
-    }
-
-    int nPS = 0;
-    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
-      nPS = Integer.parseInt(
-          parametersHolder.getOptionValue(CliConstants.N_PS));
-    }
-
-    // Check #workers and #ps.
-    // When distributed training is required
-    if (nWorkers >= 2 && nPS > 0) {
-      distributed = true;
-    } else if (nWorkers <= 1 && nPS > 0) {
-      throw new ParseException("Only specified one worker but non-zero PS, "
-          + "please double check.");
-    }
 
     if (parametersHolder.hasOption(CliConstants.INSECURE_CLUSTER)) {
       setSecurityDisabled(true);
@@ -109,46 +95,6 @@ public class RunJobParameters extends RunParameters {
         CliConstants.PRINCIPAL);
     CliUtils.doLoginIfSecure(kerberosKeytab, kerberosPrincipal);
 
-    workerResource = null;
-    if (nWorkers > 0) {
-      String workerResourceStr = parametersHolder.getOptionValue(
-          CliConstants.WORKER_RES);
-      if (workerResourceStr == null) {
-        throw new ParseException(
-            "--" + CliConstants.WORKER_RES + " is absent.");
-      }
-      workerResource = ResourceUtils.createResourceFromString(
-          workerResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-    }
-
-    Resource psResource = null;
-    if (nPS > 0) {
-      String psResourceStr = parametersHolder.getOptionValue(
-          CliConstants.PS_RES);
-      if (psResourceStr == null) {
-        throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
-      }
-      psResource = ResourceUtils.createResourceFromString(psResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-    }
-
-    boolean tensorboard = false;
-    if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
-      tensorboard = true;
-      String tensorboardResourceStr = parametersHolder.getOptionValue(
-          CliConstants.TENSORBOARD_RESOURCES);
-      if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
-        tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
-      }
-      tensorboardResource = ResourceUtils.createResourceFromString(
-          tensorboardResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-      tensorboardDockerImage = parametersHolder.getOptionValue(
-          CliConstants.TENSORBOARD_DOCKER_IMAGE);
-      this.setTensorboardResource(tensorboardResource);
-    }
-
     if (parametersHolder.hasOption(CliConstants.WAIT_JOB_FINISH)) {
       this.waitJobFinish = true;
     }
@@ -164,16 +110,6 @@ public class RunJobParameters extends RunParameters {
       }
     }
 
-    psDockerImage = parametersHolder.getOptionValue(
-        CliConstants.PS_DOCKER_IMAGE);
-    workerDockerImage = parametersHolder.getOptionValue(
-        CliConstants.WORKER_DOCKER_IMAGE);
-
-    String workerLaunchCmd = parametersHolder.getOptionValue(
-        CliConstants.WORKER_LAUNCH_CMD);
-    String psLaunchCommand = parametersHolder.getOptionValue(
-        CliConstants.PS_LAUNCH_CMD);
-
     // Localizations
     List<String> localizationsStr = parametersHolder.getOptionValues(
         CliConstants.LOCALIZATION);
@@ -191,10 +127,6 @@ public class RunJobParameters extends RunParameters {
         .getOptionValues(CliConstants.ARG_CONF);
 
     this.setInputPath(input).setCheckpointPath(jobDir)
-        .setNumPS(nPS).setNumWorkers(nWorkers)
-        .setPSLaunchCmd(psLaunchCommand).setWorkerLaunchCmd(workerLaunchCmd)
-        .setPsResource(psResource)
-        .setTensorboardEnabled(tensorboard)
         .setKeytab(kerberosKeytab)
         .setPrincipal(kerberosPrincipal)
         .setDistributeKeytab(distributeKerberosKeytab)
@@ -203,6 +135,39 @@ public class RunJobParameters extends RunParameters {
     super.updateParameters(parametersHolder, clientContext);
   }
 
+  abstract void executePostOperations(ClientContext clientContext)
+      throws IOException;
+
+  void setDefaultDirs(ClientContext clientContext) throws IOException {
+    // Create directories if needed
+    String jobDir = getCheckpointPath();
+    if (jobDir == null) {
+      jobDir = getJobDir(clientContext);
+      setCheckpointPath(jobDir);
+    }
+
+    if (getNumWorkers() > 0) {
+      String savedModelDir = getSavedModelPath();
+      if (savedModelDir == null) {
+        savedModelDir = jobDir;
+        setSavedModelPath(savedModelDir);
+      }
+    }
+  }
+
+  private String getJobDir(ClientContext clientContext) throws IOException {
+    RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
+    if (getNumWorkers() > 0) {
+      return rdm.getJobCheckpointDir(getName(), true).toString();
+    } else {
+      // when #workers == 0, it means we only launch TB. In that case,
+      // point job dir to root dir so all job's metrics will be shown.
+      return rdm.getUserRootFolder().toString();
+    }
+  }
+
+  public abstract List<String> getLaunchCommands();
+
   public String getInputPath() {
     return input;
   }
@@ -221,110 +186,10 @@ public class RunJobParameters extends RunParameters {
     return this;
   }
 
-  public int getNumWorkers() {
-    return numWorkers;
-  }
-
-  public RunJobParameters setNumWorkers(int numWorkers) {
-    this.numWorkers = numWorkers;
-    return this;
-  }
-
-  public int getNumPS() {
-    return numPS;
-  }
-
-  public RunJobParameters setNumPS(int numPS) {
-    this.numPS = numPS;
-    return this;
-  }
-
-  public Resource getWorkerResource() {
-    return workerResource;
-  }
-
-  public RunJobParameters setWorkerResource(Resource workerResource) {
-    this.workerResource = workerResource;
-    return this;
-  }
-
-  public Resource getPsResource() {
-    return psResource;
-  }
-
-  public RunJobParameters setPsResource(Resource psResource) {
-    this.psResource = psResource;
-    return this;
-  }
-
-  public boolean isTensorboardEnabled() {
-    return tensorboardEnabled;
-  }
-
-  public RunJobParameters setTensorboardEnabled(boolean tensorboardEnabled) {
-    this.tensorboardEnabled = tensorboardEnabled;
-    return this;
-  }
-
-  public String getWorkerLaunchCmd() {
-    return workerLaunchCmd;
-  }
-
-  public RunJobParameters setWorkerLaunchCmd(String workerLaunchCmd) {
-    this.workerLaunchCmd = workerLaunchCmd;
-    return this;
-  }
-
-  public String getPSLaunchCmd() {
-    return psLaunchCmd;
-  }
-
-  public RunJobParameters setPSLaunchCmd(String psLaunchCmd) {
-    this.psLaunchCmd = psLaunchCmd;
-    return this;
-  }
-
   public boolean isWaitJobFinish() {
     return waitJobFinish;
   }
 
-
-  public String getPsDockerImage() {
-    return psDockerImage;
-  }
-
-  public void setPsDockerImage(String psDockerImage) {
-    this.psDockerImage = psDockerImage;
-  }
-
-  public String getWorkerDockerImage() {
-    return workerDockerImage;
-  }
-
-  public void setWorkerDockerImage(String workerDockerImage) {
-    this.workerDockerImage = workerDockerImage;
-  }
-
-  public boolean isDistributed() {
-    return distributed;
-  }
-
-  public Resource getTensorboardResource() {
-    return tensorboardResource;
-  }
-
-  public void setTensorboardResource(Resource tensorboardResource) {
-    this.tensorboardResource = tensorboardResource;
-  }
-
-  public String getTensorboardDockerImage() {
-    return tensorboardDockerImage;
-  }
-
-  public void setTensorboardDockerImage(String tensorboardDockerImage) {
-    this.tensorboardDockerImage = tensorboardDockerImage;
-  }
-
   public List<Quicklink> getQuicklinks() {
     return quicklinks;
   }
@@ -382,6 +247,90 @@ public class RunJobParameters extends RunParameters {
     this.distributed = distributed;
   }
 
+  RoleParameters getWorkerParameters(ClientContext clientContext,
+      ParametersHolder parametersHolder, String input)
+      throws ParseException, YarnException, IOException {
+    int nWorkers = getNumberOfWorkers(parametersHolder, input);
+    Resource workerResource =
+        determineWorkerResource(parametersHolder, nWorkers, clientContext);
+    String workerDockerImage =
+        parametersHolder.getOptionValue(CliConstants.WORKER_DOCKER_IMAGE);
+    String workerLaunchCmd =
+        parametersHolder.getOptionValue(CliConstants.WORKER_LAUNCH_CMD);
+    return new RoleParameters(TensorFlowRole.WORKER, nWorkers,
+        workerLaunchCmd, workerDockerImage, workerResource);
+  }
+
+  private Resource determineWorkerResource(ParametersHolder parametersHolder,
+      int nWorkers, ClientContext clientContext)
+      throws ParseException, YarnException, IOException {
+    if (nWorkers > 0) {
+      String workerResourceStr =
+          parametersHolder.getOptionValue(CliConstants.WORKER_RES);
+      if (workerResourceStr == null) {
+        throw new ParseException(
+            "--" + CliConstants.WORKER_RES + " is absent.");
+      }
+      return ResourceUtils.createResourceFromString(workerResourceStr,
+          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
+    }
+    return null;
+  }
+
+  private int getNumberOfWorkers(ParametersHolder parametersHolder,
+      String input) throws ParseException, YarnException {
+    int nWorkers = 1;
+    if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
+      nWorkers = Integer
+          .parseInt(parametersHolder.getOptionValue(CliConstants.N_WORKERS));
+      // Only check null value.
+      // Training job shouldn't ignore INPUT_PATH option
+      // But if nWorkers is 0, INPUT_PATH can be ignored because
+      // user can only run Tensorboard
+      if (null == input && 0 != nWorkers) {
+        throw new ParseException(
+            "\"--" + CliConstants.INPUT_PATH + "\" is absent");
+      }
+    }
+    return nWorkers;
+  }
+
+  public String getWorkerLaunchCmd() {
+    return workerParameters.getLaunchCommand();
+  }
+
+  public void setWorkerLaunchCmd(String launchCmd) {
+    workerParameters.setLaunchCommand(launchCmd);
+  }
+
+  public int getNumWorkers() {
+    return workerParameters.getReplicas();
+  }
+
+  public void setNumWorkers(int numWorkers) {
+    workerParameters.setReplicas(numWorkers);
+  }
+
+  public Resource getWorkerResource() {
+    return workerParameters.getResource();
+  }
+
+  public void setWorkerResource(Resource resource) {
+    workerParameters.setResource(resource);
+  }
+
+  public String getWorkerDockerImage() {
+    return workerParameters.getDockerImage();
+  }
+
+  public void setWorkerDockerImage(String image) {
+    workerParameters.setDockerImage(image);
+  }
+
+  public boolean isDistributed() {
+    return distributed;
+  }
+
   @VisibleForTesting
   public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
     @Override

+ 215 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/TensorFlowRunJobParameters.java

@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Parameters for TensorFlow job.
+ */
+public class TensorFlowRunJobParameters extends RunJobParameters {
+  private boolean tensorboardEnabled;
+  private RoleParameters psParameters =
+      RoleParameters.createEmpty(TensorFlowRole.PS);
+  private RoleParameters tensorBoardParameters =
+      RoleParameters.createEmpty(TensorFlowRole.TENSORBOARD);
+
+  @Override
+  public void updateParameters(ParametersHolder parametersHolder,
+      ClientContext clientContext)
+      throws ParseException, IOException, YarnException {
+    super.updateParameters(parametersHolder, clientContext);
+
+    String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
+    this.workerParameters =
+        getWorkerParameters(clientContext, parametersHolder, input);
+    this.psParameters = getPSParameters(clientContext, parametersHolder);
+    this.distributed = determineIfDistributed(workerParameters.getReplicas(),
+        psParameters.getReplicas());
+
+    if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
+      this.tensorboardEnabled = true;
+      this.tensorBoardParameters =
+          getTensorBoardParameters(parametersHolder, clientContext);
+    }
+    executePostOperations(clientContext);
+  }
+
+  @Override
+  void executePostOperations(ClientContext clientContext) throws IOException {
+    // Set default job dir / saved model dir, etc.
+    setDefaultDirs(clientContext);
+    replacePatternsInParameters(clientContext);
+  }
+
+  private void replacePatternsInParameters(ClientContext clientContext)
+      throws IOException {
+    if (StringUtils.isNotEmpty(getPSLaunchCmd())) {
+      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
+          getPSLaunchCmd(), this, clientContext.getRemoteDirectoryManager());
+      setPSLaunchCmd(afterReplace);
+    }
+
+    if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
+      String afterReplace =
+          CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
+              clientContext.getRemoteDirectoryManager());
+      setWorkerLaunchCmd(afterReplace);
+    }
+  }
+
+  @Override
+  public List<String> getLaunchCommands() {
+    return Lists.newArrayList(getWorkerLaunchCmd(), getPSLaunchCmd());
+  }
+
+  private boolean determineIfDistributed(int nWorkers, int nPS)
+      throws ParseException {
+    // Check #workers and #ps.
+    // When distributed training is required
+    if (nWorkers >= 2 && nPS > 0) {
+      return true;
+    } else if (nWorkers <= 1 && nPS > 0) {
+      throw new ParseException("Only specified one worker but non-zero PS, "
+          + "please double check.");
+    }
+    return false;
+  }
+
+  private RoleParameters getPSParameters(ClientContext clientContext,
+      ParametersHolder parametersHolder)
+      throws YarnException, IOException, ParseException {
+    int nPS = getNumberOfPS(parametersHolder);
+    Resource psResource =
+        determinePSResource(parametersHolder, nPS, clientContext);
+    String psDockerImage =
+        parametersHolder.getOptionValue(CliConstants.PS_DOCKER_IMAGE);
+    String psLaunchCommand =
+        parametersHolder.getOptionValue(CliConstants.PS_LAUNCH_CMD);
+    return new RoleParameters(TensorFlowRole.PS, nPS, psLaunchCommand,
+        psDockerImage, psResource);
+  }
+
+  private Resource determinePSResource(ParametersHolder parametersHolder,
+      int nPS, ClientContext clientContext)
+      throws ParseException, YarnException, IOException {
+    if (nPS > 0) {
+      String psResourceStr =
+          parametersHolder.getOptionValue(CliConstants.PS_RES);
+      if (psResourceStr == null) {
+        throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
+      }
+      return ResourceUtils.createResourceFromString(psResourceStr,
+          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
+    }
+    return null;
+  }
+
+  private int getNumberOfPS(ParametersHolder parametersHolder)
+      throws YarnException {
+    int nPS = 0;
+    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
+      nPS =
+          Integer.parseInt(parametersHolder.getOptionValue(CliConstants.N_PS));
+    }
+    return nPS;
+  }
+
+  private RoleParameters getTensorBoardParameters(
+      ParametersHolder parametersHolder, ClientContext clientContext)
+      throws YarnException, IOException {
+    String tensorboardResourceStr =
+        parametersHolder.getOptionValue(CliConstants.TENSORBOARD_RESOURCES);
+    if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
+      tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
+    }
+    Resource tensorboardResource =
+        ResourceUtils.createResourceFromString(tensorboardResourceStr,
+            clientContext.getOrCreateYarnClient().getResourceTypeInfo());
+    String tensorboardDockerImage =
+        parametersHolder.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE);
+    return new RoleParameters(TensorFlowRole.TENSORBOARD, 1, null,
+        tensorboardDockerImage, tensorboardResource);
+  }
+
+  public int getNumPS() {
+    return psParameters.getReplicas();
+  }
+
+  public void setNumPS(int numPS) {
+    psParameters.setReplicas(numPS);
+  }
+
+  public Resource getPsResource() {
+    return psParameters.getResource();
+  }
+
+  public void setPsResource(Resource resource) {
+    psParameters.setResource(resource);
+  }
+
+  public String getPsDockerImage() {
+    return psParameters.getDockerImage();
+  }
+
+  public void setPsDockerImage(String image) {
+    psParameters.setDockerImage(image);
+  }
+
+  public String getPSLaunchCmd() {
+    return psParameters.getLaunchCommand();
+  }
+
+  public void setPSLaunchCmd(String launchCmd) {
+    psParameters.setLaunchCommand(launchCmd);
+  }
+
+  public boolean isTensorboardEnabled() {
+    return tensorboardEnabled;
+  }
+
+  public Resource getTensorboardResource() {
+    return tensorBoardParameters.getResource();
+  }
+
+  public void setTensorboardResource(Resource resource) {
+    tensorBoardParameters.setResource(resource);
+  }
+
+  public String getTensorboardDockerImage() {
+    return tensorBoardParameters.getDockerImage();
+  }
+
+  public void setTensorboardDockerImage(String image) {
+    tensorBoardParameters.setDockerImage(image);
+  }
+
+}

+ 20 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/runjob/package-info.java

@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes that hold run job parameters for
+ * TensorFlow / PyTorch jobs.
+ */
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;

+ 9 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/yaml/Spec.java

@@ -22,6 +22,7 @@ package org.apache.hadoop.yarn.submarine.client.cli.param.yaml;
 public class Spec {
   private String name;
   private String jobType;
+  private String framework;
 
   public String getJobType() {
     return jobType;
@@ -38,4 +39,12 @@ public class Spec {
   public void setName(String name) {
     this.name = name;
   }
+
+  public String getFramework() {
+    return framework;
+  }
+
+  public void setFramework(String framework) {
+    this.framework = framework;
+  }
 }

+ 59 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/Framework.java

@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Represents the type of Machine learning framework to work with.
+ */
+public enum Framework {
+  TENSORFLOW(Constants.TENSORFLOW_NAME), PYTORCH(Constants.PYTORCH_NAME);
+
+  private String value;
+
+  Framework(String value) {
+    this.value = value;
+  }
+
+  public String getValue() {
+    return value;
+  }
+
+  public static Framework parseByValue(String value) {
+    for (Framework fw : Framework.values()) {
+      if (fw.value.equalsIgnoreCase(value)) {
+        return fw;
+      }
+    }
+    return null;
+  }
+
+  public static String getValues() {
+    List<String> values = Lists.newArrayList(Framework.values()).stream()
+        .map(fw -> fw.value).collect(Collectors.toList());
+    return String.join(",", values);
+  }
+
+  private static class Constants {
+    static final String TENSORFLOW_NAME = "tensorflow";
+    static final String PYTORCH_NAME = "pytorch";
+  }
+}

+ 81 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/RoleParameters.java

@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+
+/**
+ * This class encapsulates data related to a particular Role.
+ * Some examples: TF Worker process, TF PS process or a PyTorch worker process.
+ */
+public class RoleParameters {
+  private final Role role;
+  private int replicas;
+  private String launchCommand;
+  private String dockerImage;
+  private Resource resource;
+
+  public RoleParameters(Role role, int replicas,
+      String launchCommand, String dockerImage, Resource resource) {
+    this.role = role;
+    this.replicas = replicas;
+    this.launchCommand = launchCommand;
+    this.dockerImage = dockerImage;
+    this.resource = resource;
+  }
+
+  public static RoleParameters createEmpty(Role role) {
+    return new RoleParameters(role, 0, null, null, null);
+  }
+
+  public Role getRole() {
+    return role;
+  }
+
+  public int getReplicas() {
+    return replicas;
+  }
+
+  public String getLaunchCommand() {
+    return launchCommand;
+  }
+
+  public void setLaunchCommand(String launchCommand) {
+    this.launchCommand = launchCommand;
+  }
+
+  public String getDockerImage() {
+    return dockerImage;
+  }
+
+  public void setDockerImage(String dockerImage) {
+    this.dockerImage = dockerImage;
+  }
+
+  public Resource getResource() {
+    return resource;
+  }
+
+  public void setResource(Resource resource) {
+    this.resource = resource;
+  }
+
+  public void setReplicas(int replicas) {
+    this.replicas = replicas;
+  }
+}

+ 123 - 106
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/RunJobCli.java → hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/RunJobCli.java

@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 /**
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -12,7 +28,7 @@
  * limitations under the License. See accompanying LICENSE file.
  */
 
-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.commons.cli.CommandLine;
@@ -23,9 +39,13 @@ import org.apache.commons.cli.ParseException;
 import org.apache.commons.io.FileUtils;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.AbstractCli;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.Command;
 import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
@@ -44,17 +64,25 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
+/**
+ * This purpose of this class is to handle / parse CLI arguments related to
+ * the run job Submarine command.
+ */
 public class RunJobCli extends AbstractCli {
   private static final Logger LOG =
       LoggerFactory.getLogger(RunJobCli.class);
-  private static final String YAML_PARSE_FAILED = "Failed to parse " +
+  private static final String CAN_BE_USED_WITH_TF_PYTORCH =
+      "Can be used with TensorFlow or PyTorch frameworks.";
+  private static final String CAN_BE_USED_WITH_TF_ONLY =
+      "Can only be used with TensorFlow framework.";
+  public static final String YAML_PARSE_FAILED = "Failed to parse " +
       "YAML config";
 
-  private Options options;
-  private RunJobParameters parameters = new RunJobParameters();
 
+  private Options options;
   private JobSubmitter jobSubmitter;
   private JobMonitor jobMonitor;
+  private ParametersHolder parametersHolder;
 
   public RunJobCli(ClientContext cliContext) {
     this(cliContext, cliContext.getRuntimeFactory().getJobSubmitterInstance(),
@@ -62,7 +90,7 @@ public class RunJobCli extends AbstractCli {
   }
 
   @VisibleForTesting
-  RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
+  public RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
       JobMonitor jobMonitor) {
     super(cliContext);
     this.options = generateOptions();
@@ -78,6 +106,10 @@ public class RunJobCli extends AbstractCli {
     Options options = new Options();
     options.addOption(CliConstants.YAML_CONFIG, true,
         "Config file (in YAML format)");
+    options.addOption(CliConstants.FRAMEWORK, true,
+        String.format("Framework to use. Valid values are: %s! " +
+                "The default framework is Tensorflow.",
+            Framework.getValues()));
     options.addOption(CliConstants.NAME, true, "Name of the job");
     options.addOption(CliConstants.INPUT_PATH, true,
         "Input of the job, could be local or other FS directory");
@@ -88,48 +120,22 @@ public class RunJobCli extends AbstractCli {
     options.addOption(CliConstants.SAVED_MODEL_PATH, true,
         "Model exported path (savedmodel) of the job, which is needed when "
             + "exported model is not placed under ${checkpoint_path}"
-            + "could be local or other FS directory. This will be used to serve.");
-    options.addOption(CliConstants.N_WORKERS, true,
-        "Number of worker tasks of the job, by default it's 1");
-    options.addOption(CliConstants.N_PS, true,
-        "Number of PS tasks of the job, by default it's 0");
-    options.addOption(CliConstants.WORKER_RES, true,
-        "Resource of each worker, for example "
-            + "memory-mb=2048,vcores=2,yarn.io/gpu=2");
-    options.addOption(CliConstants.PS_RES, true,
-        "Resource of each PS, for example "
-            + "memory-mb=2048,vcores=2,yarn.io/gpu=2");
+            + "could be local or other FS directory. " +
+            "This will be used to serve.");
     options.addOption(CliConstants.DOCKER_IMAGE, true, "Docker image name/tag");
     options.addOption(CliConstants.QUEUE, true,
         "Name of queue to run the job, by default it uses default queue");
-    options.addOption(CliConstants.TENSORBOARD, false,
-        "Should we run TensorBoard"
-            + " for this job? By default it's disabled");
-    options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
-        "Specify resources of Tensorboard, by default it is "
-            + CliConstants.TENSORBOARD_DEFAULT_RESOURCES);
-    options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
-        "Specify Tensorboard docker image. when this is not "
-            + "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
-            + " as default.");
-    options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
-        "Commandline of worker, arguments will be "
-            + "directly used to launch the worker");
-    options.addOption(CliConstants.PS_LAUNCH_CMD, true,
-        "Commandline of worker, arguments will be "
-            + "directly used to launch the PS");
+
+    addWorkerOptions(options);
+    addPSOptions(options);
+    addTensorboardOptions(options);
+
     options.addOption(CliConstants.ENV, true,
         "Common environment variable of worker/ps");
     options.addOption(CliConstants.VERBOSE, false,
         "Print verbose log for troubleshooting");
     options.addOption(CliConstants.WAIT_JOB_FINISH, false,
         "Specified when user want to wait the job finish");
-    options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
-        "Specify docker image for PS, when this is not specified, PS uses --"
-            + CliConstants.DOCKER_IMAGE + " as default.");
-    options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
-        "Specify docker image for WORKER, when this is not specified, WORKER "
-            + "uses --" + CliConstants.DOCKER_IMAGE + " as default.");
     options.addOption(CliConstants.QUICKLINK, true, "Specify quicklink so YARN"
         + "web UI shows link to given role instance and port. When "
         + "--tensorboard is specified, quicklink to tensorboard instance will "
@@ -172,63 +178,97 @@ public class RunJobCli extends AbstractCli {
     return options;
   }
 
-  private void replacePatternsInParameters() throws IOException {
-    if (parameters.getPSLaunchCmd() != null && !parameters.getPSLaunchCmd()
-        .isEmpty()) {
-      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
-          parameters.getPSLaunchCmd(), parameters,
-          clientContext.getRemoteDirectoryManager());
-      parameters.setPSLaunchCmd(afterReplace);
-    }
+  private void addWorkerOptions(Options options) {
+    options.addOption(CliConstants.N_WORKERS, true,
+        "Number of worker tasks of the job, by default it's 1." +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
+        "Specify docker image for WORKER, when this is not specified, WORKER "
+            + "uses --" + CliConstants.DOCKER_IMAGE + " as default." +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
+        "Commandline of worker, arguments will be "
+            + "directly used to launch the worker" +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_RES, true,
+        "Resource of each worker, for example "
+            + "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+  }
 
-    if (parameters.getWorkerLaunchCmd() != null && !parameters
-        .getWorkerLaunchCmd().isEmpty()) {
-      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
-          parameters.getWorkerLaunchCmd(), parameters,
-          clientContext.getRemoteDirectoryManager());
-      parameters.setWorkerLaunchCmd(afterReplace);
-    }
+  private void addPSOptions(Options options) {
+    options.addOption(CliConstants.N_PS, true,
+        "Number of PS tasks of the job, by default it's 0. " +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
+        "Specify docker image for PS, when this is not specified, PS uses --"
+            + CliConstants.DOCKER_IMAGE + " as default." +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_LAUNCH_CMD, true,
+        "Commandline of worker, arguments will be "
+            + "directly used to launch the PS" +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_RES, true,
+        "Resource of each PS, for example "
+            + "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
+            CAN_BE_USED_WITH_TF_ONLY);
+  }
+
+  private void addTensorboardOptions(Options options) {
+    options.addOption(CliConstants.TENSORBOARD, false,
+        "Should we run TensorBoard"
+            + " for this job? By default it's disabled." +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
+        "Specify resources of Tensorboard, by default it is "
+            + CliConstants.TENSORBOARD_DEFAULT_RESOURCES + "." +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
+        "Specify Tensorboard docker image. when this is not "
+            + "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
+            + " as default." +
+            CAN_BE_USED_WITH_TF_ONLY);
   }
 
   private void parseCommandLineAndGetRunJobParameters(String[] args)
       throws ParseException, IOException, YarnException {
     try {
-      // Do parsing
       GnuParser parser = new GnuParser();
       CommandLine cli = parser.parse(options, args);
-      ParametersHolder parametersHolder = createParametersHolder(cli);
-      parameters.updateParameters(parametersHolder, clientContext);
+      parametersHolder = createParametersHolder(cli);
+      parametersHolder.updateParameters(clientContext);
     } catch (ParseException e) {
       LOG.error("Exception in parse: {}", e.getMessage());
       printUsages();
       throw e;
     }
-
-    // Set default job dir / saved model dir, etc.
-    setDefaultDirs();
-
-    // replace patterns
-    replacePatternsInParameters();
   }
 
-  private ParametersHolder createParametersHolder(CommandLine cli) {
+  private ParametersHolder createParametersHolder(CommandLine cli)
+      throws ParseException, YarnException {
     String yamlConfigFile =
         cli.getOptionValue(CliConstants.YAML_CONFIG);
     if (yamlConfigFile != null) {
       YamlConfigFile yamlConfig = readYamlConfigFile(yamlConfigFile);
-      if (yamlConfig == null) {
-        throw new YamlParseException(String.format(
-            YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
-      } else if (yamlConfig.getConfigs() == null) {
-        throw new YamlParseException(String.format(YAML_PARSE_FAILED +
-            ", config section should be defined, but it cannot be found in " +
-            "YAML file '%s'!", yamlConfigFile));
-      }
+      checkYamlConfig(yamlConfigFile, yamlConfig);
       LOG.info("Using YAML configuration!");
-      return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig);
+      return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig,
+          Command.RUN_JOB);
     } else {
       LOG.info("Using CLI configuration!");
-      return ParametersHolder.createWithCmdLine(cli);
+      return ParametersHolder.createWithCmdLine(cli, Command.RUN_JOB);
+    }
+  }
+
+  private void checkYamlConfig(String yamlConfigFile,
+      YamlConfigFile yamlConfig) {
+    if (yamlConfig == null) {
+      throw new YamlParseException(String.format(
+          YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
+    } else if (yamlConfig.getConfigs() == null) {
+      throw new YamlParseException(String.format(YAML_PARSE_FAILED +
+          ", config section should be defined, but it cannot be found in " +
+          "YAML file '%s'!", yamlConfigFile));
     }
   }
 
@@ -256,34 +296,9 @@ public class RunJobCli extends AbstractCli {
         e);
   }
 
-  private void setDefaultDirs() throws IOException {
-    // Create directories if needed
-    String jobDir = parameters.getCheckpointPath();
-    if (null == jobDir) {
-      if (parameters.getNumWorkers() > 0) {
-        jobDir = clientContext.getRemoteDirectoryManager().getJobCheckpointDir(
-            parameters.getName(), true).toString();
-      } else {
-        // when #workers == 0, it means we only launch TB. In that case,
-        // point job dir to root dir so all job's metrics will be shown.
-        jobDir = clientContext.getRemoteDirectoryManager().getUserRootFolder()
-            .toString();
-      }
-      parameters.setCheckpointPath(jobDir);
-    }
-
-    if (parameters.getNumWorkers() > 0) {
-      // Only do this when #worker > 0
-      String savedModelDir = parameters.getSavedModelPath();
-      if (null == savedModelDir) {
-        savedModelDir = jobDir;
-        parameters.setSavedModelPath(savedModelDir);
-      }
-    }
-  }
-
-  private void storeJobInformation(String jobName, ApplicationId applicationId,
-      String[] args) throws IOException {
+  private void storeJobInformation(RunJobParameters parameters,
+      ApplicationId applicationId, String[] args) throws IOException {
+    String jobName = parameters.getName();
     Map<String, String> jobInfo = new HashMap<>();
     jobInfo.put(StorageKeyConstants.JOB_NAME, jobName);
     jobInfo.put(StorageKeyConstants.APPLICATION_ID, applicationId.toString());
@@ -316,8 +331,10 @@ public class RunJobCli extends AbstractCli {
     }
 
     parseCommandLineAndGetRunJobParameters(args);
-    ApplicationId applicationId = this.jobSubmitter.submitJob(parameters);
-    storeJobInformation(parameters.getName(), applicationId, args);
+    ApplicationId applicationId = jobSubmitter.submitJob(parametersHolder);
+    RunJobParameters parameters =
+        (RunJobParameters) parametersHolder.getParameters();
+    storeJobInformation(parameters, applicationId, args);
     if (parameters.isWaitJobFinish()) {
       this.jobMonitor.waitTrainingFinal(parameters.getName());
     }
@@ -332,6 +349,6 @@ public class RunJobCli extends AbstractCli {
 
   @VisibleForTesting
   public RunJobParameters getRunJobParameters() {
-    return parameters;
+    return (RunJobParameters) parametersHolder.getParameters();
   }
 }

+ 19 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/package-info.java

@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes that are related to the run job command.
+ */
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;

+ 54 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/PyTorchRole.java

@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. See accompanying LICENSE file.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+/**
+ * Enum to represent a PyTorch Role.
+ */
+public enum PyTorchRole implements Role {
+  PRIMARY_WORKER("master"),
+  WORKER("worker");
+
+  private String compName;
+
+  PyTorchRole(String compName) {
+    this.compName = compName;
+  }
+
+  public String getComponentName() {
+    return compName;
+  }
+
+  @Override
+  public String getName() {
+    return name();
+  }
+}

+ 25 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/Role.java

@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+/**
+ * Interface for a Role.
+ */
+public interface Role {
+  String getComponentName();
+  String getName();
+}

+ 58 - 0
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/Runtime.java

@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+import com.google.common.collect.Lists;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Represents the type of Runtime.
+ */
+public enum Runtime {
+  TONY(Constants.TONY), YARN_SERVICE(Constants.YARN_SERVICE);
+
+  private String value;
+
+  Runtime(String value) {
+    this.value = value;
+  }
+
+  public String getValue() {
+    return value;
+  }
+
+  public static Runtime parseByValue(String value) {
+    for (Runtime rt : Runtime.values()) {
+      if (rt.value.equalsIgnoreCase(value)) {
+        return rt;
+      }
+    }
+    return null;
+  }
+
+  public static String getValues() {
+    List<String> values = Lists.newArrayList(Runtime.values()).stream()
+        .map(rt -> rt.value).collect(Collectors.toList());
+    return String.join(",", values);
+  }
+
+  public static class Constants {
+    public static final String TONY = "tony";
+    public static final String YARN_SERVICE = "yarnservice";
+  }
+}

+ 11 - 2
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/TaskType.java → hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/common/api/TensorFlowRole.java

@@ -14,7 +14,10 @@
 
 package org.apache.hadoop.yarn.submarine.common.api;
 
-public enum TaskType {
+/**
+ * Enum to represent a TensorFlow Role.
+ */
+public enum TensorFlowRole implements Role {
   PRIMARY_WORKER("master"),
   WORKER("worker"),
   PS("ps"),
@@ -22,11 +25,17 @@ public enum TaskType {
 
   private String compName;
 
-  TaskType(String compName) {
+  TensorFlowRole(String compName) {
     this.compName = compName;
   }
 
+  @Override
   public String getComponentName() {
     return compName;
   }
+
+  @Override
+  public String getName() {
+    return name();
+  }
 }

+ 5 - 5
hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/common/JobSubmitter.java

@@ -16,21 +16,21 @@ package org.apache.hadoop.yarn.submarine.runtimes.common;
 
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
 
 import java.io.IOException;
 
 /**
- * Submit job to cluster master
+ * Submit job to cluster master.
  */
 public interface JobSubmitter {
   /**
-   * Submit job to cluster
+   * Submit a job to cluster.
    * @param parameters run job parameters
-   * @return applicatioId when successfully submitted
+   * @return applicationId when successfully submitted
    * @throws YarnException for issues while contacting YARN daemons
    * @throws IOException for other issues.
    */
-  ApplicationId submitJob(RunJobParameters parameters)
+  ApplicationId submitJob(ParametersHolder parameters)
       throws IOException, YarnException;
 }

+ 7 - 0
hadoop-submarine/hadoop-submarine-core/src/site/markdown/QuickStart.md

@@ -40,6 +40,10 @@ More details, please refer to
 
 ```$xslt
 usage: job run
+
+ -framework <arg>             Framework to use.
+                              Valid values are: tensorflow, pytorch.
+                              The default framework is Tensorflow.
  -checkpoint_path <arg>       Training output directory of the job, could
                               be local or other FS directory. This
                               typically includes checkpoint files and
@@ -130,6 +134,7 @@ For submarine internal configuration, please create a `submarine.xml` which shou
 #### Commandline
 ```
 yarn jar path-to/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar job run \
+  --framework tensorflow \
   --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
   --env DOCKER_HADOOP_HDFS_HOME=/hadoop-current --name tf-job-001 \
   --docker_image <your-docker-image> \
@@ -163,6 +168,7 @@ See below screenshot:
 ```
 yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
  --name tf-job-001 --docker_image <your-docker-image> \
+ --framework tensorflow \
  --input_path hdfs://default/dataset/cifar-10-data \
  --checkpoint_path hdfs://default/tmp/cifar-10-jobdir \
  --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
@@ -208,6 +214,7 @@ After that, you can run ```tensorboard --logdir=<checkpoint-path>``` to view Ten
 yarn app -destroy tensorboard-service; \
 yarn jar /tmp/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar \
   job run --name tensorboard-service --verbose --docker_image <your-docker-image> \
+  --framework tensorflow \
   --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
   --env DOCKER_HADOOP_HDFS_HOME=/hadoop-current \
   --num_workers 0 --tensorboard

+ 0 - 226
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java

@@ -1,226 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * <p>
- * http://www.apache.org/licenses/LICENSE-2.0
- * <p>
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.hadoop.yarn.submarine.client.cli;
-
-import org.apache.commons.cli.ParseException;
-import org.apache.hadoop.yarn.api.records.ApplicationId;
-import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.MockClientContext;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
-import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
-import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
-import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
-import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
-import org.apache.hadoop.yarn.util.resource.Resources;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-import java.io.IOException;
-
-import static org.junit.Assert.assertEquals;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-public class TestRunJobCliParsing {
-
-  @Before
-  public void before() {
-    SubmarineLogs.verboseOff();
-  }
-
-  @Test
-  public void testPrintHelp() {
-    MockClientContext mockClientContext = new MockClientContext();
-    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
-    JobMonitor mockJobMonitor = mock(JobMonitor.class);
-    RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
-        mockJobMonitor);
-    runJobCli.printUsages();
-  }
-
-  static MockClientContext getMockClientContext()
-      throws IOException, YarnException {
-    MockClientContext mockClientContext = new MockClientContext();
-    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
-    when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
-        ApplicationId.newInstance(1234L, 1));
-    JobMonitor mockJobMonitor = mock(JobMonitor.class);
-    SubmarineStorage storage = mock(SubmarineStorage.class);
-    RuntimeFactory rtFactory = mock(RuntimeFactory.class);
-
-    when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
-    when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
-    when(rtFactory.getSubmarineStorage()).thenReturn(storage);
-
-    mockClientContext.setRuntimeFactory(rtFactory);
-    return mockClientContext;
-  }
-
-  @Test
-  public void testBasicRunJobForDistributedTraining() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    runJobCli.run(
-        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-            "--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
-            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
-            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
-            "--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
-            "--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
-            "--principal", "user/_HOST@domain.com", "--distribute_keytab",
-            "--verbose" });
-
-    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
-
-    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
-    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
-    assertEquals(jobRunParameters.getNumPS(), 2);
-    assertEquals(jobRunParameters.getPSLaunchCmd(), "python run-ps.py");
-    assertEquals(Resources.createResource(4096, 4),
-        jobRunParameters.getPsResource());
-    assertEquals(jobRunParameters.getWorkerLaunchCmd(),
-        "python run-job.py");
-    assertEquals(Resources.createResource(2048, 2),
-        jobRunParameters.getWorkerResource());
-    assertEquals(jobRunParameters.getDockerImageName(),
-        "tf-docker:1.1.0");
-    assertEquals(jobRunParameters.getKeytab(),
-        "/keytab/path");
-    assertEquals(jobRunParameters.getPrincipal(),
-        "user/_HOST@domain.com");
-    Assert.assertTrue(jobRunParameters.isDistributeKeytab());
-    Assert.assertTrue(SubmarineLogs.isVerbose());
-  }
-
-  @Test
-  public void testBasicRunJobForSingleNodeTraining() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    runJobCli.run(
-        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-            "--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
-            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
-            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
-            "true", "--verbose", "--wait_job_finish" });
-
-    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
-
-    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
-    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
-    assertEquals(jobRunParameters.getNumWorkers(), 1);
-    assertEquals(jobRunParameters.getWorkerLaunchCmd(),
-        "python run-job.py");
-    assertEquals(Resources.createResource(4096, 2),
-        jobRunParameters.getWorkerResource());
-    Assert.assertTrue(SubmarineLogs.isVerbose());
-    Assert.assertTrue(jobRunParameters.isWaitJobFinish());
-  }
-
-  @Test
-  public void testNoInputPathOptionSpecified() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\" is absent";
-    String actualMessage = "";
-    try {
-      runJobCli.run(
-          new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-              "--checkpoint_path", "hdfs://output",
-              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
-              "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
-              "true", "--verbose", "--wait_job_finish"});
-    } catch (ParseException e) {
-      actualMessage = e.getMessage();
-      e.printStackTrace();
-    }
-    assertEquals(expectedErrorMessage, actualMessage);
-  }
-
-  /**
-   * when only run tensorboard, input_path is not needed
-   * */
-  @Test
-  public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    boolean success = true;
-    try {
-      runJobCli.run(
-          new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-              "--num_workers", "0", "--tensorboard", "--verbose",
-              "--tensorboard_resources", "memory=2G,vcores=2",
-              "--tensorboard_docker_image", "tb_docker_image:001"});
-    } catch (ParseException e) {
-      success = false;
-    }
-    Assert.assertTrue(success);
-  }
-
-  @Test
-  public void testJobWithoutName() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    String expectedErrorMessage =
-        "--" + CliConstants.NAME + " is absent";
-    String actualMessage = "";
-    try {
-      runJobCli.run(
-          new String[]{"--docker_image", "tf-docker:1.1.0",
-              "--num_workers", "0", "--tensorboard", "--verbose",
-              "--tensorboard_resources", "memory=2G,vcores=2",
-              "--tensorboard_docker_image", "tb_docker_image:001"});
-    } catch (ParseException e) {
-      actualMessage = e.getMessage();
-      e.printStackTrace();
-    }
-    assertEquals(expectedErrorMessage, actualMessage);
-  }
-
-  @Test
-  public void testLaunchCommandPatternReplace() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    runJobCli.run(
-        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
-            "--input_path", "hdfs://input", "--checkpoint_path",
-            "hdfs://output",
-            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
-            "python run-job.py --input=%input_path% " +
-                "--model_dir=%checkpoint_path% " +
-                "--export_dir=%saved_model_path%/savedmodel",
-            "--worker_resources", "memory=2048,vcores=2", "--ps_resources",
-            "memory=4096,vcores=4", "--tensorboard", "true", "--ps_launch_cmd",
-            "python run-ps.py --input=%input_path% " +
-                "--model_dir=%checkpoint_path%/model",
-            "--verbose" });
-
-    assertEquals(
-        "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
-            + "--export_dir=hdfs://output/savedmodel",
-        runJobCli.getRunJobParameters().getWorkerLaunchCmd());
-    assertEquals(
-        "python run-ps.py --input=hdfs://input --model_dir=hdfs://output/model",
-        runJobCli.getRunJobParameters().getPSLaunchCmd());
-  }
-}

+ 6 - 5
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/YamlConfigTestUtils.java

@@ -17,7 +17,7 @@
 package org.apache.hadoop.yarn.submarine.client.cli;
 
 import org.apache.commons.io.FileUtils;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
 import org.yaml.snakeyaml.Yaml;
 import org.yaml.snakeyaml.constructor.Constructor;
@@ -33,13 +33,13 @@ public final class YamlConfigTestUtils {
 
   private YamlConfigTestUtils() {}
 
-  static void deleteFile(File file) {
+  public static void deleteFile(File file) {
     if (file != null) {
       file.delete();
     }
   }
 
-  static YamlConfigFile readYamlConfigFile(String filename) {
+  public static YamlConfigFile readYamlConfigFile(String filename) {
     Constructor constructor = new Constructor(YamlConfigFile.class);
     constructor.setPropertyUtils(new UnderscoreConverterPropertyUtils());
     Yaml yaml = new Yaml(constructor);
@@ -49,7 +49,8 @@ public final class YamlConfigTestUtils {
     return yaml.loadAs(inputStream, YamlConfigFile.class);
   }
 
-  static File createTempFileWithContents(String filename) throws IOException {
+  public static File createTempFileWithContents(String filename)
+      throws IOException {
     InputStream inputStream = YamlConfigTestUtils.class
         .getClassLoader()
         .getResourceAsStream(filename);
@@ -58,7 +59,7 @@ public final class YamlConfigTestUtils {
     return targetFile;
   }
 
-  static File createEmptyTempFile() throws IOException {
+  public static File createEmptyTempFile() throws IOException {
     return File.createTempFile("test", ".yaml");
   }
 

+ 129 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingCommon.java

@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import org.apache.commons.cli.MissingArgumentException;
+import org.apache.commons.cli.ParseException;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
+import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import java.io.IOException;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * This class contains some test methods to test common functionality
+ * (including TF / PyTorch) of the run job Submarine command.
+ */
+public class TestRunJobCliParsingCommon {
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  public static MockClientContext getMockClientContext()
+      throws IOException, YarnException {
+    MockClientContext mockClientContext = new MockClientContext();
+    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
+    when(mockJobSubmitter.submitJob(any(ParametersHolder.class)))
+        .thenReturn(ApplicationId.newInstance(1235L, 1));
+
+    JobMonitor mockJobMonitor = mock(JobMonitor.class);
+    SubmarineStorage storage = mock(SubmarineStorage.class);
+    RuntimeFactory rtFactory = mock(RuntimeFactory.class);
+
+    when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
+    when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
+    when(rtFactory.getSubmarineStorage()).thenReturn(storage);
+
+    mockClientContext.setRuntimeFactory(rtFactory);
+    return mockClientContext;
+  }
+
+  @Test
+  public void testAbsentFrameworkFallsBackToTensorFlow() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    runJobCli.run(
+        new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+            "true", "--verbose", "--wait_job_finish"});
+    RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
+    assertTrue("Default Framework should be TensorFlow!",
+        runJobParameters instanceof TensorFlowRunJobParameters);
+  }
+
+  @Test
+  public void testEmptyFrameworkOption() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(MissingArgumentException.class);
+    expectedException.expectMessage("Missing argument for option: framework");
+
+    runJobCli.run(
+        new String[]{"--framework", "--name", "my-job",
+            "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+            "true", "--verbose", "--wait_job_finish"});
+  }
+
+  @Test
+  public void testInvalidFrameworkOption() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("Failed to parse Framework type");
+
+    runJobCli.run(
+        new String[]{"--framework", "bla", "--name", "my-job",
+            "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+            "true", "--verbose", "--wait_job_finish"});
+  }
+}

+ 252 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingCommonYaml.java

@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import org.apache.hadoop.yarn.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * This class contains some test methods to test common YAML parsing
+ * functionality (including TF / PyTorch) of the run job Submarine command.
+ */
+public class TestRunJobCliParsingCommonYaml {
+  private static final String DIR_NAME = "runjob-common-yaml";
+  private static final String TF_DIR = "runjob-pytorch-yaml";
+  private File yamlConfig;
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @After
+  public void after() {
+    YamlConfigTestUtils.deleteFile(yamlConfig);
+  }
+
+  @BeforeClass
+  public static void configureResourceTypes() {
+    List<ResourceTypeInfo> resTypes = new ArrayList<>(
+        ResourceUtils.getResourcesTypeInfo());
+    resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
+    ResourceUtils.reinitializeResources(resTypes);
+  }
+
+  @Rule
+  public ExpectedException exception = ExpectedException.none();
+
+  @Test
+  public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        TF_DIR + "/valid-config.yaml");
+    String[] args = new String[] {"--name", "my-job",
+        "--docker_image", "tf-docker:1.1.0",
+        "-f", yamlConfig.getAbsolutePath() };
+
+    exception.expect(YarnException.class);
+    exception.expectMessage("defined both with YAML config and with " +
+        "CLI argument");
+
+    runJobCli.run(args);
+  }
+
+  @Test
+  public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
+      throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        TF_DIR + "/valid-config.yaml");
+    String[] args = new String[] {"--name", "my-job",
+        "--quicklink", "AAA=http://master-0:8321",
+        "--quicklink", "BBB=http://worker-0:1234",
+        "-f", yamlConfig.getAbsolutePath()};
+
+    exception.expect(YarnException.class);
+    exception.expectMessage("defined both with YAML config and with " +
+        "CLI argument");
+
+    runJobCli.run(args);
+  }
+
+  @Test
+  public void testFalseValuesForBooleanFields() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/test-false-values.yaml");
+    runJobCli.run(
+        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertFalse(jobRunParameters.isDistributeKeytab());
+    assertFalse(jobRunParameters.isWaitJobFinish());
+    assertFalse(tensorFlowParams.isTensorboardEnabled());
+  }
+
+  @Test
+  public void testWrongIndentation() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/wrong-indentation.yaml");
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("Failed to parse YAML config, details:");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testWrongFilename() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    exception.expect(YamlParseException.class);
+    runJobCli.run(
+        new String[]{"-f", "not-existing", "--verbose"});
+  }
+
+  @Test
+  public void testEmptyFile() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
+
+    exception.expect(YamlParseException.class);
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testNotExistingFile() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("file does not exist");
+    runJobCli.run(
+        new String[]{"-f", "blabla", "--verbose"});
+  }
+
+  @Test
+  public void testWrongPropertyName() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/wrong-property-name.yaml");
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("Failed to parse YAML config, details:");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testMissingConfigsSection() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/missing-configs.yaml");
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("config section should be defined, " +
+        "but it cannot be found");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testMissingSectionsShouldParsed() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/some-sections-missing.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+
+  @Test
+  public void testAbsentFramework() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/missing-framework.yaml");
+
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testEmptyFramework() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/empty-framework.yaml");
+
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testInvalidFramework() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/invalid-framework.yaml");
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("framework should is defined, " +
+        "but it has an invalid value");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+}

+ 192 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/TestRunJobCliParsingParameterized.java

@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+import com.google.common.collect.Lists;
+import org.apache.commons.cli.ParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+/**
+ * This class contains some test methods to test common CLI parsing
+ * functionality (including TF / PyTorch) of the run job Submarine command.
+ */
+@RunWith(Parameterized.class)
+public class TestRunJobCliParsingParameterized {
+
+  private final Framework framework;
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  @Parameterized.Parameters
+  public static Collection<Object[]> data() {
+    Collection<Object[]> params = new ArrayList<>();
+    params.add(new Object[]{Framework.TENSORFLOW });
+    params.add(new Object[]{Framework.PYTORCH });
+    return params;
+  }
+
+  public TestRunJobCliParsingParameterized(Framework framework) {
+    this.framework = framework;
+  }
+
+  private String getFrameworkName() {
+    return framework.getValue();
+  }
+
+  @Test
+  public void testPrintHelp() {
+    MockClientContext mockClientContext = new MockClientContext();
+    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
+    JobMonitor mockJobMonitor = mock(JobMonitor.class);
+    RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
+        mockJobMonitor);
+    runJobCli.printUsages();
+  }
+
+  @Test
+  public void testNoInputPathOptionSpecified() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\"" +
+        " is absent";
+    String actualMessage = "";
+    try {
+      runJobCli.run(
+          new String[]{"--framework", getFrameworkName(),
+              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+              "--checkpoint_path", "hdfs://output",
+              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+              "--worker_resources", "memory=4g,vcores=2", "--verbose",
+              "--wait_job_finish"});
+    } catch (ParseException e) {
+      actualMessage = e.getMessage();
+      e.printStackTrace();
+    }
+    assertEquals(expectedErrorMessage, actualMessage);
+  }
+
+  @Test
+  public void testJobWithoutName() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    String expectedErrorMessage =
+        "--" + CliConstants.NAME + " is absent";
+    String actualMessage = "";
+    try {
+      runJobCli.run(
+          new String[]{"--framework", getFrameworkName(),
+              "--docker_image", "tf-docker:1.1.0",
+              "--num_workers", "0", "--verbose"});
+    } catch (ParseException e) {
+      actualMessage = e.getMessage();
+      e.printStackTrace();
+    }
+    assertEquals(expectedErrorMessage, actualMessage);
+  }
+
+  @Test
+  public void testLaunchCommandPatternReplace() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    List<String> parameters = Lists.newArrayList("--framework",
+        getFrameworkName(),
+        "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+        "--input_path", "hdfs://input", "--checkpoint_path",
+        "hdfs://output",
+        "--num_workers", "3",
+        "--worker_launch_cmd", "python run-job.py --input=%input_path% " +
+            "--model_dir=%checkpoint_path% " +
+            "--export_dir=%saved_model_path%/savedmodel",
+        "--worker_resources", "memory=2048,vcores=2");
+
+    if (framework == Framework.TENSORFLOW) {
+      parameters.addAll(Lists.newArrayList(
+          "--ps_resources", "memory=4096,vcores=4",
+          "--ps_launch_cmd", "python run-ps.py --input=%input_path% " +
+              "--model_dir=%checkpoint_path%/model",
+          "--verbose"));
+    }
+
+    runJobCli.run(parameters.toArray(new String[0]));
+
+    RunJobParameters runJobParameters = checkExpectedFrameworkParams(runJobCli);
+
+    if (framework == Framework.TENSORFLOW) {
+      TensorFlowRunJobParameters tensorFlowParams =
+          (TensorFlowRunJobParameters) runJobParameters;
+      assertEquals(
+          "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+              + "--export_dir=hdfs://output/savedmodel",
+          tensorFlowParams.getWorkerLaunchCmd());
+      assertEquals(
+          "python run-ps.py --input=hdfs://input " +
+              "--model_dir=hdfs://output/model",
+          tensorFlowParams.getPSLaunchCmd());
+    } else if (framework == Framework.PYTORCH) {
+      PyTorchRunJobParameters pyTorchParameters =
+          (PyTorchRunJobParameters) runJobParameters;
+      assertEquals(
+          "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+              + "--export_dir=hdfs://output/savedmodel",
+          pyTorchParameters.getWorkerLaunchCmd());
+    }
+  }
+
+  private RunJobParameters checkExpectedFrameworkParams(RunJobCli runJobCli) {
+    RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
+
+    if (framework == Framework.TENSORFLOW) {
+      assertTrue(RunJobParameters.class + " must be an instance of " +
+              TensorFlowRunJobParameters.class,
+          runJobParameters instanceof TensorFlowRunJobParameters);
+    } else if (framework == Framework.PYTORCH) {
+      assertTrue(RunJobParameters.class + " must be an instance of " +
+              PyTorchRunJobParameters.class,
+          runJobParameters instanceof PyTorchRunJobParameters);
+    }
+    return runJobParameters;
+  }
+
+}

+ 209 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/pytorch/TestRunJobCliParsingPyTorch.java

@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
+
+import org.apache.commons.cli.ParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.util.resource.Resources;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Test class that verifies the correctness of PyTorch
+ * CLI configuration parsing.
+ */
+public class TestRunJobCliParsingPyTorch {
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  @Test
+  public void testBasicRunJobForSingleNodeTraining() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--verbose",
+            "--wait_job_finish" });
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    assertTrue(RunJobParameters.class +
+            " must be an instance of " +
+            PyTorchRunJobParameters.class,
+        jobRunParameters instanceof PyTorchRunJobParameters);
+    PyTorchRunJobParameters pyTorchParams =
+        (PyTorchRunJobParameters) jobRunParameters;
+
+    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
+    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
+    assertEquals(pyTorchParams.getNumWorkers(), 1);
+    assertEquals(pyTorchParams.getWorkerLaunchCmd(),
+        "python run-job.py");
+    assertEquals(Resources.createResource(4096, 2),
+        pyTorchParams.getWorkerResource());
+    assertTrue(SubmarineLogs.isVerbose());
+    assertTrue(jobRunParameters.isWaitJobFinish());
+  }
+
+  @Test
+  public void testNumPSCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path","hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--num_ps", "2" });
+  }
+
+  @Test
+  public void testPSResourcesCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--ps_resources", "memory=2048M,vcores=2" });
+  }
+
+  @Test
+  public void testPSDockerImageCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--ps_docker_image", "psDockerImage" });
+  }
+
+  @Test
+  public void testPSLaunchCommandCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--ps_launch_cmd", "psLaunchCommand" });
+  }
+
+  @Test
+  public void testTensorboardCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--tensorboard" });
+  }
+
+  @Test
+  public void testTensorboardResourcesCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--tensorboard_resources", "memory=2048M,vcores=2" });
+  }
+
+  @Test
+  public void testTensorboardDockerImageCannotBeDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    expectedException.expect(ParseException.class);
+    expectedException.expectMessage("cannot be defined for PyTorch jobs");
+    runJobCli.run(
+        new String[] {"--framework", "pytorch",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3",
+            "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--tensorboard_docker_image", "TBDockerImage" });
+  }
+
+}

+ 225 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/pytorch/TestRunJobCliParsingPyTorchYaml.java

@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.yarn.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
+import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
+import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+/**
+ * Test class that verifies the correctness of PyTorch
+ * YAML configuration parsing.
+ */
+public class TestRunJobCliParsingPyTorchYaml {
+  private static final String OVERRIDDEN_PREFIX = "overridden_";
+  private static final String DIR_NAME = "runjob-pytorch-yaml";
+  private File yamlConfig;
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @After
+  public void after() {
+    YamlConfigTestUtils.deleteFile(yamlConfig);
+  }
+
+  @BeforeClass
+  public static void configureResourceTypes() {
+    List<ResourceTypeInfo> resTypes = new ArrayList<>(
+        ResourceUtils.getResourcesTypeInfo());
+    resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
+    ResourceUtils.reinitializeResources(resTypes);
+  }
+
+  @Rule
+  public ExpectedException exception = ExpectedException.none();
+
+  private void verifyBasicConfigValues(RunJobParameters jobRunParameters) {
+    verifyBasicConfigValues(jobRunParameters,
+        ImmutableList.of("env1=env1Value", "env2=env2Value"));
+  }
+
+  private void verifyBasicConfigValues(RunJobParameters jobRunParameters,
+      List<String> expectedEnvs) {
+    assertEquals("testInputPath", jobRunParameters.getInputPath());
+    assertEquals("testCheckpointPath", jobRunParameters.getCheckpointPath());
+    assertEquals("testDockerImage", jobRunParameters.getDockerImageName());
+
+    assertNotNull(jobRunParameters.getLocalizations());
+    assertEquals(2, jobRunParameters.getLocalizations().size());
+
+    assertNotNull(jobRunParameters.getQuicklinks());
+    assertEquals(2, jobRunParameters.getQuicklinks().size());
+
+    assertTrue(SubmarineLogs.isVerbose());
+    assertTrue(jobRunParameters.isWaitJobFinish());
+
+    for (String env : expectedEnvs) {
+      assertTrue(String.format(
+          "%s should be in env list of jobRunParameters!", env),
+          jobRunParameters.getEnvars().contains(env));
+    }
+  }
+
+  private void verifyWorkerValues(RunJobParameters jobRunParameters,
+      String prefix) {
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            PyTorchRunJobParameters.class,
+        jobRunParameters instanceof PyTorchRunJobParameters);
+    PyTorchRunJobParameters tensorFlowParams =
+        (PyTorchRunJobParameters) jobRunParameters;
+
+    assertEquals(3, tensorFlowParams.getNumWorkers());
+    assertEquals(prefix + "testLaunchCmdWorker",
+        tensorFlowParams.getWorkerLaunchCmd());
+    assertEquals(prefix + "testDockerImageWorker",
+        tensorFlowParams.getWorkerDockerImage());
+    assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
+        ImmutableMap.<String, String> builder()
+            .put(ResourceInformation.GPU_URI, "2").build()),
+        tensorFlowParams.getWorkerResource());
+  }
+
+  private void verifySecurityValues(RunJobParameters jobRunParameters) {
+    assertEquals("keytabPath", jobRunParameters.getKeytab());
+    assertEquals("testPrincipal", jobRunParameters.getPrincipal());
+    assertTrue(jobRunParameters.isDistributeKeytab());
+  }
+
+  @Test
+  public void testValidYamlParsing() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/valid-config.yaml");
+    runJobCli.run(
+        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    verifyBasicConfigValues(jobRunParameters);
+    verifyWorkerValues(jobRunParameters, "");
+    verifySecurityValues(jobRunParameters);
+  }
+
+  @Test
+  public void testRoleOverrides() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    Assert.assertFalse(SubmarineLogs.isVerbose());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/valid-config-with-overrides.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    verifyBasicConfigValues(jobRunParameters);
+    verifyWorkerValues(jobRunParameters, OVERRIDDEN_PREFIX);
+    verifySecurityValues(jobRunParameters);
+  }
+
+  @Test
+  public void testMissingPrincipalUnderSecuritySection() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/security-principal-is-missing.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    verifyBasicConfigValues(jobRunParameters);
+    verifyWorkerValues(jobRunParameters, "");
+
+    //Verify security values
+    assertEquals("keytabPath", jobRunParameters.getKeytab());
+    assertNull("Principal should be null!", jobRunParameters.getPrincipal());
+    assertTrue(jobRunParameters.isDistributeKeytab());
+  }
+
+  @Test
+  public void testMissingEnvs() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/envs-are-missing.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    verifyBasicConfigValues(jobRunParameters, ImmutableList.of());
+    verifyWorkerValues(jobRunParameters, "");
+    verifySecurityValues(jobRunParameters);
+  }
+
+  @Test
+  public void testInvalidConfigPsSectionIsDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("PS section should not be defined " +
+        "when PyTorch is the selected framework");
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/invalid-config-ps-section.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+  @Test
+  public void testInvalidConfigTensorboardSectionIsDefined() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    exception.expect(YamlParseException.class);
+    exception.expectMessage("TensorBoard section should not be defined " +
+        "when PyTorch is the selected framework");
+    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
+        DIR_NAME + "/invalid-config-tensorboard-section.yaml");
+    runJobCli.run(
+        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
+  }
+
+}

+ 170 - 0
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlow.java

@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
+
+import org.apache.commons.cli.ParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.util.resource.Resources;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Test class that verifies the correctness of TensorFlow
+ * CLI configuration parsing.
+ */
+public class TestRunJobCliParsingTensorFlow {
+
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  @Test
+  public void testNoInputPathOptionSpecified() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH +
+        "\" is absent";
+    String actualMessage = "";
+    try {
+      runJobCli.run(
+          new String[]{"--framework", "tensorflow",
+              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+              "--checkpoint_path", "hdfs://output",
+              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+              "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+              "true", "--verbose", "--wait_job_finish"});
+    } catch (ParseException e) {
+      actualMessage = e.getMessage();
+      e.printStackTrace();
+    }
+    assertEquals(expectedErrorMessage, actualMessage);
+  }
+
+  @Test
+  public void testBasicRunJobForDistributedTraining() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+
+    assertFalse(SubmarineLogs.isVerbose());
+
+    runJobCli.run(
+        new String[] { "--framework", "tensorflow",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input",
+            "--checkpoint_path", "hdfs://output",
+            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
+            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
+            "--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
+            "--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
+            "--principal", "user/_HOST@domain.com", "--distribute_keytab",
+            "--verbose" });
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    assertTrue(RunJobParameters.class +
+        " must be an instance of " +
+        TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
+    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
+    assertEquals(tensorFlowParams.getNumPS(), 2);
+    assertEquals(tensorFlowParams.getPSLaunchCmd(), "python run-ps.py");
+    assertEquals(Resources.createResource(4096, 4),
+        tensorFlowParams.getPsResource());
+    assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
+        "python run-job.py");
+    assertEquals(Resources.createResource(2048, 2),
+        tensorFlowParams.getWorkerResource());
+    assertEquals(jobRunParameters.getDockerImageName(),
+        "tf-docker:1.1.0");
+    assertEquals(jobRunParameters.getKeytab(),
+        "/keytab/path");
+    assertEquals(jobRunParameters.getPrincipal(),
+        "user/_HOST@domain.com");
+    assertTrue(jobRunParameters.isDistributeKeytab());
+    assertTrue(SubmarineLogs.isVerbose());
+  }
+
+  @Test
+  public void testBasicRunJobForSingleNodeTraining() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    assertFalse(SubmarineLogs.isVerbose());
+
+    runJobCli.run(
+        new String[] { "--framework", "tensorflow",
+            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+            "--input_path", "hdfs://input", "--checkpoint_path",
+            "hdfs://output",
+            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
+            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
+            "true", "--verbose", "--wait_job_finish" });
+
+    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+    assertTrue(RunJobParameters.class +
+            " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
+    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
+    assertEquals(tensorFlowParams.getNumWorkers(), 1);
+    assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
+        "python run-job.py");
+    assertEquals(Resources.createResource(4096, 2),
+        tensorFlowParams.getWorkerResource());
+    assertTrue(SubmarineLogs.isVerbose());
+    assertTrue(jobRunParameters.isWaitJobFinish());
+  }
+
+  /**
+   * when only run tensorboard, input_path is not needed
+   * */
+  @Test
+  public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
+    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+    boolean success = true;
+    try {
+      runJobCli.run(
+          new String[]{"--framework", "tensorflow",
+              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+              "--num_workers", "0", "--tensorboard", "--verbose",
+              "--tensorboard_resources", "memory=2G,vcores=2",
+              "--tensorboard_docker_image", "tb_docker_image:001"});
+    } catch (ParseException e) {
+      success = false;
+    }
+    assertTrue(success);
+  }
+}

+ 45 - 159
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsingYaml.java → hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYaml.java

@@ -15,16 +15,17 @@
  */
 
 
-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.apache.hadoop.yarn.api.records.ResourceInformation;
 import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
-import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.util.resource.ResourceUtils;
 import org.junit.After;
@@ -39,19 +40,18 @@ import java.io.File;
 import java.util.ArrayList;
 import java.util.List;
 
-import static org.apache.hadoop.yarn.submarine.client.cli.TestRunJobCliParsing.getMockClientContext;
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
 /**
- * Test class that verifies the correctness of YAML configuration parsing.
+ * Test class that verifies the correctness of TF YAML configuration parsing.
  */
-public class TestRunJobCliParsingYaml {
+public class TestRunJobCliParsingTensorFlowYaml {
   private static final String OVERRIDDEN_PREFIX = "overridden_";
-  private static final String DIR_NAME = "runjobcliparsing";
+  private static final String DIR_NAME = "runjob-tensorflow-yaml";
   private File yamlConfig;
 
   @Before
@@ -104,27 +104,39 @@ public class TestRunJobCliParsingYaml {
 
   private void verifyPsValues(RunJobParameters jobRunParameters,
       String prefix) {
-    assertEquals(4, jobRunParameters.getNumPS());
-    assertEquals(prefix + "testLaunchCmdPs", jobRunParameters.getPSLaunchCmd());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(4, tensorFlowParams.getNumPS());
+    assertEquals(prefix + "testLaunchCmdPs", tensorFlowParams.getPSLaunchCmd());
     assertEquals(prefix + "testDockerImagePs",
-        jobRunParameters.getPsDockerImage());
+        tensorFlowParams.getPsDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(20500L, 34,
         ImmutableMap.<String, String> builder()
             .put(ResourceInformation.GPU_URI, "4").build()),
-        jobRunParameters.getPsResource());
+        tensorFlowParams.getPsResource());
   }
 
   private void verifyWorkerValues(RunJobParameters jobRunParameters,
       String prefix) {
-    assertEquals(3, jobRunParameters.getNumWorkers());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(3, tensorFlowParams.getNumWorkers());
     assertEquals(prefix + "testLaunchCmdWorker",
-        jobRunParameters.getWorkerLaunchCmd());
+        tensorFlowParams.getWorkerLaunchCmd());
     assertEquals(prefix + "testDockerImageWorker",
-        jobRunParameters.getWorkerDockerImage());
+        tensorFlowParams.getWorkerDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
         ImmutableMap.<String, String> builder()
             .put(ResourceInformation.GPU_URI, "2").build()),
-        jobRunParameters.getWorkerResource());
+        tensorFlowParams.getWorkerResource());
   }
 
   private void verifySecurityValues(RunJobParameters jobRunParameters) {
@@ -134,13 +146,19 @@ public class TestRunJobCliParsingYaml {
   }
 
   private void verifyTensorboardValues(RunJobParameters jobRunParameters) {
-    assertTrue(jobRunParameters.isTensorboardEnabled());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertTrue(tensorFlowParams.isTensorboardEnabled());
     assertEquals("tensorboardDockerImage",
-        jobRunParameters.getTensorboardDockerImage());
+        tensorFlowParams.getTensorboardDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
         ImmutableMap.<String, String> builder()
             .put(ResourceInformation.GPU_URI, "3").build()),
-        jobRunParameters.getTensorboardResource());
+        tensorFlowParams.getTensorboardResource());
   }
 
   @Test
@@ -161,44 +179,6 @@ public class TestRunJobCliParsingYaml {
     verifyTensorboardValues(jobRunParameters);
   }
 
-  @Test
-  public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/valid-config.yaml");
-    String[] args = new String[] {"--name", "my-job",
-        "--docker_image", "tf-docker:1.1.0",
-        "-f", yamlConfig.getAbsolutePath() };
-
-    exception.expect(YarnException.class);
-    exception.expectMessage("defined both with YAML config and with " +
-        "CLI argument");
-
-    runJobCli.run(args);
-  }
-
-  @Test
-  public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
-      throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/valid-config.yaml");
-    String[] args = new String[] {"--name", "my-job",
-        "--quicklink", "AAA=http://master-0:8321",
-        "--quicklink", "BBB=http://worker-0:1234",
-        "-f", yamlConfig.getAbsolutePath()};
-
-    exception.expect(YarnException.class);
-    exception.expectMessage("defined both with YAML config and with " +
-        "CLI argument");
-
-    runJobCli.run(args);
-  }
-
   @Test
   public void testRoleOverrides() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@@ -217,104 +197,6 @@ public class TestRunJobCliParsingYaml {
     verifyTensorboardValues(jobRunParameters);
   }
 
-  @Test
-  public void testFalseValuesForBooleanFields() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/test-false-values.yaml");
-    runJobCli.run(
-        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
-
-    assertFalse(jobRunParameters.isDistributeKeytab());
-    assertFalse(jobRunParameters.isWaitJobFinish());
-    assertFalse(jobRunParameters.isTensorboardEnabled());
-  }
-
-  @Test
-  public void testWrongIndentation() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/wrong-indentation.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("Failed to parse YAML config, details:");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testWrongFilename() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    exception.expect(YamlParseException.class);
-    runJobCli.run(
-        new String[]{"-f", "not-existing", "--verbose"});
-  }
-
-  @Test
-  public void testEmptyFile() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
-
-    exception.expect(YamlParseException.class);
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testNotExistingFile() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("file does not exist");
-    runJobCli.run(
-        new String[]{"-f", "blabla", "--verbose"});
-  }
-
-  @Test
-  public void testWrongPropertyName() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/wrong-property-name.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("Failed to parse YAML config, details:");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testMissingConfigsSection() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/missing-configs.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("config section should be defined, " +
-        "but it cannot be found");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testMissingSectionsShouldParsed() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/some-sections-missing.yaml");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
   @Test
   public void testMissingPrincipalUnderSecuritySection() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@@ -346,18 +228,22 @@ public class TestRunJobCliParsingYaml {
         new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
 
     RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+
     verifyBasicConfigValues(jobRunParameters);
     verifyPsValues(jobRunParameters, "");
     verifyWorkerValues(jobRunParameters, "");
     verifySecurityValues(jobRunParameters);
 
-    assertTrue(jobRunParameters.isTensorboardEnabled());
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertTrue(tensorFlowParams.isTensorboardEnabled());
     assertNull("tensorboardDockerImage should be null!",
-        jobRunParameters.getTensorboardDockerImage());
+        tensorFlowParams.getTensorboardDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
         ImmutableMap.<String, String> builder()
             .put(ResourceInformation.GPU_URI, "3").build()),
-        jobRunParameters.getTensorboardResource());
+        tensorFlowParams.getTensorboardResource());
   }
 
   @Test

+ 8 - 9
hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsingYamlStandalone.java → hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYamlStandalone.java

@@ -15,7 +15,7 @@
  */
 
 
-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
 
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
@@ -42,14 +42,9 @@ import static org.junit.Assert.assertTrue;
  * Please note that this class just tests YAML parsing,
  * but only in an isolated fashion.
  */
-public class TestRunJobCliParsingYamlStandalone {
+public class TestRunJobCliParsingTensorFlowYamlStandalone {
   private static final String OVERRIDDEN_PREFIX = "overridden_";
-  private static final String DIR_NAME = "runjobcliparsing";
-
-  @Before
-  public void before() {
-    SubmarineLogs.verboseOff();
-  }
+  private static final String DIR_NAME = "runjob-tensorflow-yaml";
 
   private void verifyBasicConfigValues(YamlConfigFile yamlConfigFile) {
     assertNotNull("Spec file should not be null!", yamlConfigFile);
@@ -169,6 +164,11 @@ public class TestRunJobCliParsingYamlStandalone {
     assertEquals("memory=21000M,vcores=37,gpu=3", tensorBoard.getResources());
   }
 
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
   @Test
   public void testLaunchCommandYaml() {
     YamlConfigFile yamlConfigFile = readYamlConfigFile(DIR_NAME +
@@ -201,5 +201,4 @@ public class TestRunJobCliParsingYamlStandalone {
     assertRoleConfigOverrides(roles.getWorker(), OVERRIDDEN_PREFIX, "Worker");
     assertRoleConfigOverrides(roles.getPs(), OVERRIDDEN_PREFIX, "Ps");
   }
-
 }

+ 63 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/empty-framework.yaml

@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework:
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+  ps:
+    resources: memory=20500M,vcores=34,gpu=4
+    replicas: 4
+    launch_cmd: testLaunchCmdPs
+    docker_image: testDockerImagePs
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true
+
+tensorBoard:
+  resources: memory=21000M,vcores=37,gpu=3
+  docker_image: tensorboardDockerImage

+ 63 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/invalid-framework.yaml

@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: bla
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+  ps:
+    resources: memory=20500M,vcores=34,gpu=4
+    replicas: 4
+    launch_cmd: testLaunchCmdPs
+    docker_image: testDockerImagePs
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true
+
+tensorBoard:
+  resources: memory=21000M,vcores=37,gpu=3
+  docker_image: tensorboardDockerImage

+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/missing-configs.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/missing-configs.yaml


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/valid-config.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/missing-framework.yaml


+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/some-sections-missing.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/some-sections-missing.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/test-false-values.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/test-false-values.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/wrong-indentation.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/wrong-indentation.yaml


+ 0 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/wrong-property-name.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-common-yaml/wrong-property-name.yaml


+ 51 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/envs-are-missing.yaml

@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true

+ 56 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/invalid-config-ps-section.yaml

@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+  ps:
+    docker_image: testPSDockerImage
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true

+ 57 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/invalid-config-tensorboard-section.yaml

@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true
+
+tensorBoard:
+  docker_image: tensorboardDockerImage

+ 53 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/security-principal-is-missing.yaml

@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+
+security:
+  keytab: keytabPath
+  distribute_keytab: true

+ 63 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/valid-config-with-overrides.yaml

@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: overridden_testLaunchCmdWorker
+    docker_image: overridden_testDockerImageWorker
+    envs:
+      env1: 'overridden_env1Worker'
+      env2: 'overridden_env2Worker'
+    localizations:
+    - hdfs://remote-file1:/overridden_local-filename1Worker:rw
+    - nfs://remote-file2:/overridden_local-filename2Worker:rw
+    mounts:
+    - /etc/passwd:/overridden_Worker
+    - /etc/hosts:/overridden_Worker
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true

+ 54 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-pytorch-yaml/valid-config.yaml

@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: pytorch
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/envs-are-missing.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/envs-are-missing.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/security-principal-is-missing.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/security-principal-is-missing.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/tensorboard-dockerimage-is-missing.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/tensorboard-dockerimage-is-missing.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 1 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjobcliparsing/valid-config-with-overrides.yaml → hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/valid-config-with-overrides.yaml

@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath

+ 63 - 0
hadoop-submarine/hadoop-submarine-core/src/test/resources/runjob-tensorflow-yaml/valid-config.yaml

@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+spec:
+  name: testJobName
+  job_type: testJobType
+  framework: tensorflow
+
+configs:
+  input_path: testInputPath
+  checkpoint_path: testCheckpointPath
+  saved_model_path: testSavedModelPath
+  docker_image: testDockerImage
+  wait_job_finish: true
+  envs:
+    env1: 'env1Value'
+    env2: 'env2Value'
+  localizations:
+  - hdfs://remote-file1:/local-filename1:rw
+  - nfs://remote-file2:/local-filename2:rw
+  mounts:
+  - /etc/passwd:/etc/passwd:rw
+  - /etc/hosts:/etc/hosts:rw
+  quicklinks:
+  - Notebook_UI=https://master-0:7070
+  - Notebook_UI2=https://master-0:7071
+
+scheduling:
+  queue: queue1
+
+roles:
+  worker:
+    resources: memory=20480M,vcores=32,gpu=2
+    replicas: 3
+    launch_cmd: testLaunchCmdWorker
+    docker_image: testDockerImageWorker
+  ps:
+    resources: memory=20500M,vcores=34,gpu=4
+    replicas: 4
+    launch_cmd: testLaunchCmdPs
+    docker_image: testDockerImagePs
+
+security:
+  keytab: keytabPath
+  principal: testPrincipal
+  distribute_keytab: true
+
+tensorBoard:
+  resources: memory=21000M,vcores=37,gpu=3
+  docker_image: tensorboardDockerImage

+ 17 - 5
hadoop-submarine/hadoop-submarine-tony-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/tony/TonyJobSubmitter.java

@@ -22,7 +22,9 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
 
 import java.io.File;
@@ -45,14 +47,24 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
   }
 
   @Override
-  public ApplicationId submitJob(RunJobParameters parameters)
-      throws IOException, YarnException {
+  public ApplicationId submitJob(ParametersHolder parameters)
+      throws IOException {
+    if (parameters.getFramework() == Framework.PYTORCH) {
+      // we need to throw an exception, as ParametersHolder's parameters field
+      // could not be casted to TensorFlowRunJobParameters.
+      throw new UnsupportedOperationException(
+          "Support \"–-framework\" option for PyTorch in Tony is coming. " +
+              "Please check the documentation about how to submit a " +
+              "PyTorch job with TonY runtime.");
+    }
+
     LOG.info("Starting Tony runtime..");
 
     File tonyFinalConfPath = File.createTempFile("temp",
         Constants.TONY_FINAL_XML);
     // Write user's overridden conf to an xml to be localized.
-    Configuration tonyConf = TonyUtils.tonyConfFromClientContext(parameters);
+    Configuration tonyConf = TonyUtils.tonyConfFromClientContext(
+        (TensorFlowRunJobParameters) parameters.getParameters());
     try (OutputStream os = new FileOutputStream(tonyFinalConfPath)) {
       tonyConf.writeXml(os);
     } catch (IOException e) {
@@ -68,7 +80,7 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
       LOG.error("Failed to init TonyClient: ", e);
     }
     Thread clientThread = new Thread(tonyClient::start);
-    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
+    java.lang.Runtime.getRuntime().addShutdownHook(new Thread(() -> {
       try {
         tonyClient.forceKillApplication();
       } catch (YarnException | IOException e) {

+ 2 - 2
hadoop-submarine/hadoop-submarine-tony-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/tony/TonyUtils.java

@@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ResourceInformation;
 import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -35,7 +35,7 @@ public final class TonyUtils {
   private static final Log LOG = LogFactory.getLog(TonyUtils.class);
 
   public static Configuration tonyConfFromClientContext(
-      RunJobParameters parameters) {
+      TensorFlowRunJobParameters parameters) {
     Configuration tonyConf = new Configuration();
     tonyConf.setInt(
         TonyConfigurationKeys.getInstancesKey(Constants.WORKER_JOB_NAME),

+ 4 - 0
hadoop-submarine/hadoop-submarine-tony-runtime/src/site/markdown/QuickStart.md

@@ -147,6 +147,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
 
 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+ --framework tensorflow \
  --num_workers 2 \
  --worker_resources memory=3G,vcores=2 \
  --num_ps 2 \
@@ -183,6 +184,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
 
 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+ --framework tensorflow \
  --docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
  --input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
  --worker_resources memory=3G,vcores=2 \
@@ -245,6 +247,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
 
 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+ --framework tensorflow \
  --num_workers 2 \
  --worker_resources memory=3G,vcores=2 \
  --num_ps 2 \
@@ -281,6 +284,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
 
 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+ --framework tensorflow \
  --docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
  --input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
  --worker_resources memory=3G,vcores=2 \

+ 19 - 7
hadoop-submarine/hadoop-submarine-tony-runtime/src/test/java/TestTonyUtils.java

@@ -16,8 +16,10 @@ import com.linkedin.tony.TonyConfigurationKeys;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.MockClientContext;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
@@ -31,6 +33,7 @@ import org.junit.Test;
 
 import java.io.IOException;
 
+import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -59,7 +62,8 @@ public class TestTonyUtils {
       throws IOException, YarnException {
     MockClientContext mockClientContext = new MockClientContext();
     JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
-    when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
+    when(mockJobSubmitter.submitJob(
+        any(ParametersHolder.class))).thenReturn(
         ApplicationId.newInstance(1234L, 1));
     JobMonitor mockJobMonitor = mock(JobMonitor.class);
     SubmarineStorage storage = mock(SubmarineStorage.class);
@@ -82,20 +86,28 @@ public class TestTonyUtils {
   public void testTonyConfFromClientContext() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
     runJobCli.run(
-        new String[] {"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+        new String[] {"--framework", "tensorflow", "--name", "my-job",
+            "--docker_image", "tf-docker:1.1.0",
             "--input_path", "hdfs://input",
             "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
             "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
             "--ps_resources", "memory=4G,vcores=4", "--ps_launch_cmd",
             "python run-ps.py"});
     RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
+
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
     Configuration tonyConf = TonyUtils
-        .tonyConfFromClientContext(jobRunParameters);
+        .tonyConfFromClientContext(tensorFlowParams);
     Assert.assertEquals(jobRunParameters.getDockerImageName(),
         tonyConf.get(TonyConfigurationKeys.getContainerDockerKey()));
     Assert.assertEquals("3", tonyConf.get(TonyConfigurationKeys
         .getInstancesKey("worker")));
-    Assert.assertEquals(jobRunParameters.getWorkerLaunchCmd(),
+    Assert.assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
         tonyConf.get(TonyConfigurationKeys
             .getExecuteCommandKey("worker")));
     Assert.assertEquals("2048", tonyConf.get(TonyConfigurationKeys
@@ -107,7 +119,7 @@ public class TestTonyUtils {
     Assert.assertEquals("4", tonyConf.get(TonyConfigurationKeys
         .getResourceKey(Constants.PS_JOB_NAME,
         Constants.VCORES)));
-    Assert.assertEquals(jobRunParameters.getPSLaunchCmd(),
+    Assert.assertEquals(tensorFlowParams.getPSLaunchCmd(),
         tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey("ps")));
   }
 }

+ 49 - 7
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java

@@ -19,8 +19,10 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
@@ -28,7 +30,11 @@ import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchComma
 import java.io.IOException;
 import java.util.Objects;
 
+import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
 
 /**
  * Abstract base class for Component classes.
@@ -40,7 +46,7 @@ import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.T
 public abstract class AbstractComponent {
   private final FileSystemOperations fsOperations;
   protected final RunJobParameters parameters;
-  protected final TaskType taskType;
+  protected final Role role;
   private final RemoteDirectoryManager remoteDirectoryManager;
   protected final Configuration yarnConfig;
   private final LaunchCommandFactory launchCommandFactory;
@@ -52,19 +58,55 @@ public abstract class AbstractComponent {
 
   public AbstractComponent(FileSystemOperations fsOperations,
       RemoteDirectoryManager remoteDirectoryManager,
-      RunJobParameters parameters, TaskType taskType,
+      RunJobParameters parameters, Role role,
       Configuration yarnConfig,
       LaunchCommandFactory launchCommandFactory) {
     this.fsOperations = fsOperations;
     this.remoteDirectoryManager = remoteDirectoryManager;
     this.parameters = parameters;
-    this.taskType = taskType;
+    this.role = role;
     this.launchCommandFactory = launchCommandFactory;
     this.yarnConfig = yarnConfig;
   }
 
   protected abstract Component createComponent() throws IOException;
 
+  protected Component createComponentInternal() throws IOException {
+    Objects.requireNonNull(this.parameters.getWorkerResource(),
+        "Worker resource must not be null!");
+    if (parameters.getNumWorkers() < 1) {
+      throw new IllegalArgumentException(
+          "Number of workers should be at least 1!");
+    }
+
+    Component component = new Component();
+    component.setName(role.getComponentName());
+
+    if (role.equals(TensorFlowRole.PRIMARY_WORKER) ||
+        role.equals(PyTorchRole.PRIMARY_WORKER)) {
+      component.setNumberOfContainers(1L);
+      component.getConfiguration().setProperty(
+          CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
+    } else {
+      component.setNumberOfContainers(
+          (long) parameters.getNumWorkers() - 1);
+    }
+
+    if (parameters.getWorkerDockerImage() != null) {
+      component.setArtifact(
+          getDockerArtifact(parameters.getWorkerDockerImage()));
+    }
+
+    component.setResource(convertYarnResourceToServiceResource(
+        parameters.getWorkerResource()));
+    component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
+
+    addCommonEnvironments(component, role);
+    generateLaunchCommand(component);
+
+    return component;
+  }
+
   /**
    * Generates a command launch script on local disk,
    * returns path to the script.
@@ -72,7 +114,7 @@ public abstract class AbstractComponent {
   protected void generateLaunchCommand(Component component)
       throws IOException {
     AbstractLaunchCommand launchCommand =
-        launchCommandFactory.createLaunchCommand(taskType, component);
+        launchCommandFactory.createLaunchCommand(role, component);
     this.localScriptFile = launchCommand.generateLaunchScript();
 
     String remoteLaunchCommand = uploadLaunchCommand(component);
@@ -86,7 +128,7 @@ public abstract class AbstractComponent {
     Path stagingDir =
         remoteDirectoryManager.getJobStagingArea(parameters.getName(), true);
 
-    String destScriptFileName = getScriptFileName(taskType);
+    String destScriptFileName = getScriptFileName(role);
     fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
         localScriptFile, destScriptFileName, component);
 

+ 167 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractServiceSpec.java

@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
+import org.apache.hadoop.yarn.submarine.utils.Localizer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
+
+/**
+ * Abstract base class that supports creating service specs for Native Service.
+ */
+public abstract class AbstractServiceSpec implements ServiceSpec {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(AbstractServiceSpec.class);
+  protected final RunJobParameters parameters;
+  protected final FileSystemOperations fsOperations;
+  private final Localizer localizer;
+  protected final RemoteDirectoryManager remoteDirectoryManager;
+  protected final Configuration yarnConfig;
+  protected final LaunchCommandFactory launchCommandFactory;
+  private final WorkerComponentFactory workerFactory;
+
+  public AbstractServiceSpec(RunJobParameters parameters,
+      ClientContext clientContext, FileSystemOperations fsOperations,
+      LaunchCommandFactory launchCommandFactory,
+      Localizer localizer) {
+    this.parameters = parameters;
+    this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
+    this.yarnConfig = clientContext.getYarnConfig();
+    this.fsOperations = fsOperations;
+    this.localizer = localizer;
+    this.launchCommandFactory = launchCommandFactory;
+    this.workerFactory = new WorkerComponentFactory(fsOperations,
+        remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
+  }
+
+  protected ServiceWrapper createServiceSpecWrapper() throws IOException {
+    Service serviceSpec = new Service();
+    serviceSpec.setName(parameters.getName());
+    serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
+    serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
+
+    KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
+        .create(fsOperations, remoteDirectoryManager, parameters);
+    if (kerberosPrincipal != null) {
+      serviceSpec.setKerberosPrincipal(kerberosPrincipal);
+    }
+
+    handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
+    localizer.handleLocalizations(serviceSpec);
+    return new ServiceWrapper(serviceSpec);
+  }
+
+
+  // Handle worker and primary_worker.
+  protected void addWorkerComponents(ServiceWrapper serviceWrapper,
+      Framework framework)
+      throws IOException {
+    final Role primaryWorkerRole;
+    final Role workerRole;
+    if (framework == Framework.TENSORFLOW) {
+      primaryWorkerRole = TensorFlowRole.PRIMARY_WORKER;
+      workerRole = TensorFlowRole.WORKER;
+    } else {
+      primaryWorkerRole = PyTorchRole.PRIMARY_WORKER;
+      workerRole = PyTorchRole.WORKER;
+    }
+
+    addWorkerComponent(serviceWrapper, primaryWorkerRole, framework);
+
+    if (parameters.getNumWorkers() > 1) {
+      addWorkerComponent(serviceWrapper, workerRole, framework);
+    }
+  }
+  private void addWorkerComponent(ServiceWrapper serviceWrapper,
+      Role role, Framework framework) throws IOException {
+    AbstractComponent component = workerFactory.create(framework, role);
+    serviceWrapper.addComponent(component);
+  }
+
+  protected void handleQuicklinks(Service serviceSpec)
+      throws IOException {
+    List<Quicklink> quicklinks = parameters.getQuicklinks();
+    if (quicklinks != null && !quicklinks.isEmpty()) {
+      for (Quicklink ql : quicklinks) {
+        // Make sure it is a valid instance name
+        String instanceName = ql.getComponentInstanceName();
+        boolean found = false;
+
+        for (Component comp : serviceSpec.getComponents()) {
+          for (int i = 0; i < comp.getNumberOfContainers(); i++) {
+            String possibleInstanceName = comp.getName() + "-" + i;
+            if (possibleInstanceName.equals(instanceName)) {
+              found = true;
+              break;
+            }
+          }
+        }
+
+        if (!found) {
+          throw new IOException(
+              "Couldn't find a component instance = " + instanceName
+                  + " while adding quicklink");
+        }
+
+        String link = ql.getProtocol()
+            + YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
+            getUserName(), getDNSDomain(yarnConfig), ql.getPort());
+        addQuicklink(serviceSpec, ql.getLabel(), link);
+      }
+    }
+  }
+
+  protected static void addQuicklink(Service serviceSpec, String label,
+      String link) {
+    Map<String, String> quicklinks = serviceSpec.getQuicklinks();
+    if (quicklinks == null) {
+      quicklinks = new HashMap<>();
+      serviceSpec.setQuicklinks(quicklinks);
+    }
+
+    if (SubmarineLogs.isVerbose()) {
+      LOG.info("Added quicklink, " + label + "=" + link);
+    }
+
+    quicklinks.put(label, link);
+  }
+}

+ 10 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java

@@ -37,6 +37,7 @@ import java.io.File;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 
 /**
@@ -195,6 +196,15 @@ public class FileSystemOperations {
     fs.setPermission(destPath, new FsPermission(permission));
   }
 
+  public static boolean needHdfs(List<String> stringsToCheck) {
+    for (String content : stringsToCheck) {
+      if (content != null && content.contains("hdfs://")) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   public static boolean needHdfs(String content) {
     return content != null && content.contains("hdfs://");
   }

+ 20 - 5
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java

@@ -16,9 +16,10 @@
 
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
 
+import org.apache.curator.shaded.com.google.common.collect.ImmutableList;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
@@ -28,6 +29,8 @@ import org.slf4j.LoggerFactory;
 import java.io.File;
 import java.io.IOException;
 import java.io.PrintWriter;
+import java.util.List;
+import java.util.Objects;
 
 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
 import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath;
@@ -128,10 +131,22 @@ public class HadoopEnvironmentSetup {
   }
 
   private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) {
-    return needHdfs(parameters.getInputPath()) ||
-        needHdfs(parameters.getPSLaunchCmd()) ||
-        needHdfs(parameters.getWorkerLaunchCmd()) ||
-        hadoopEnv;
+    List<String> launchCommands = parameters.getLaunchCommands();
+    if (launchCommands != null) {
+      launchCommands.removeIf(Objects::isNull);
+    }
+
+    ImmutableList.Builder<String> listBuilder = ImmutableList.builder();
+
+    if (launchCommands != null && !launchCommands.isEmpty()) {
+      listBuilder.addAll(launchCommands);
+    }
+    if (parameters.getInputPath() != null) {
+      listBuilder.add(parameters.getInputPath());
+    }
+    List<String> stringsToCheck = listBuilder.build();
+
+    return needHdfs(stringsToCheck) || hadoopEnv;
   }
 
   private void appendHdfsHome(PrintWriter fw, String hdfsHome) {

+ 1 - 1
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java

@@ -38,7 +38,7 @@ public final class ServiceSpecFileGenerator {
         "instantiated!");
   }
 
-  static String generateJson(Service service) throws IOException {
+  public static String generateJson(Service service) throws IOException {
     File serviceSpecFile = File.createTempFile(service.getName(), ".json");
     String buffer = jsonSerDeser.toJson(service);
     Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),

+ 71 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/WorkerComponentFactory.java

@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component.PyTorchWorkerComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
+
+/**
+ * Factory class that helps creating Native Service components.
+ */
+public class WorkerComponentFactory {
+  private final FileSystemOperations fsOperations;
+  private final RemoteDirectoryManager remoteDirectoryManager;
+  private final RunJobParameters parameters;
+  private final LaunchCommandFactory launchCommandFactory;
+  private final Configuration yarnConfig;
+
+  WorkerComponentFactory(FileSystemOperations fsOperations,
+      RemoteDirectoryManager remoteDirectoryManager,
+      RunJobParameters parameters,
+      LaunchCommandFactory launchCommandFactory,
+      Configuration yarnConfig) {
+    this.fsOperations = fsOperations;
+    this.remoteDirectoryManager = remoteDirectoryManager;
+    this.parameters = parameters;
+    this.launchCommandFactory = launchCommandFactory;
+    this.yarnConfig = yarnConfig;
+  }
+
+  /**
+   * Creates either a TensorFlow or a PyTorch Native Service component.
+   */
+  public AbstractComponent create(Framework framework, Role role) {
+    if (framework == Framework.TENSORFLOW) {
+      return new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
+          (TensorFlowRunJobParameters) parameters, role,
+          (TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
+    } else if (framework == Framework.PYTORCH) {
+      return new PyTorchWorkerComponent(fsOperations, remoteDirectoryManager,
+          (PyTorchRunJobParameters) parameters, role,
+          (PyTorchLaunchCommandFactory) launchCommandFactory, yarnConfig);
+    } else {
+      throw new UnsupportedOperationException("Only supported frameworks are: "
+          + Framework.getValues());
+    }
+  }
+}

+ 62 - 7
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java

@@ -20,10 +20,16 @@ import org.apache.hadoop.yarn.client.api.AppAdminClient;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.service.api.records.Service;
 import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.PyTorchServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
 import org.apache.hadoop.yarn.submarine.utils.Localizer;
 import org.slf4j.Logger;
@@ -32,6 +38,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 
 import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
+import static org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder.SUPPORTED_FRAMEWORKS_MESSAGE;
 
 /**
  * Submit a job to cluster.
@@ -51,14 +58,45 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
    * {@inheritDoc}
    */
   @Override
-  public ApplicationId submitJob(RunJobParameters parameters)
+  public ApplicationId submitJob(ParametersHolder paramsHolder)
       throws IOException, YarnException {
+    Framework framework = paramsHolder.getFramework();
+    RunJobParameters parameters =
+        (RunJobParameters) paramsHolder.getParameters();
+
+    if (framework == Framework.TENSORFLOW) {
+      return submitTensorFlowJob((TensorFlowRunJobParameters) parameters);
+    } else if (framework == Framework.PYTORCH) {
+      return submitPyTorchJob((PyTorchRunJobParameters) parameters);
+    } else {
+      throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
+    }
+  }
+
+  private ApplicationId submitTensorFlowJob(
+      TensorFlowRunJobParameters parameters) throws IOException, YarnException {
     FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
     HadoopEnvironmentSetup hadoopEnvSetup =
         new HadoopEnvironmentSetup(clientContext, fsOperations);
 
     Service serviceSpec = createTensorFlowServiceSpec(parameters,
         fsOperations, hadoopEnvSetup);
+    return submitJobInternal(serviceSpec);
+  }
+
+  private ApplicationId submitPyTorchJob(PyTorchRunJobParameters parameters)
+      throws IOException, YarnException {
+    FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
+    HadoopEnvironmentSetup hadoopEnvSetup =
+        new HadoopEnvironmentSetup(clientContext, fsOperations);
+
+    Service serviceSpec = createPyTorchServiceSpec(parameters,
+        fsOperations, hadoopEnvSetup);
+    return submitJobInternal(serviceSpec);
+  }
+
+  private ApplicationId submitJobInternal(Service serviceSpec)
+      throws IOException, YarnException {
     String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);
 
     AppAdminClient appAdminClient =
@@ -70,7 +108,7 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
           "Fail to launch application with exit code:" + code);
     }
 
-    String appStatus=appAdminClient.getStatusString(serviceSpec.getName());
+    String appStatus = appAdminClient.getStatusString(serviceSpec.getName());
     Service app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus);
 
     // Retry multiple times if applicationId is null
@@ -97,11 +135,12 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
     return appid;
   }
 
-  private Service createTensorFlowServiceSpec(RunJobParameters parameters,
+  private Service createTensorFlowServiceSpec(
+      TensorFlowRunJobParameters parameters,
       FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
       throws IOException {
-    LaunchCommandFactory launchCommandFactory =
-        new LaunchCommandFactory(hadoopEnvSetup, parameters,
+    TensorFlowLaunchCommandFactory launchCommandFactory =
+        new TensorFlowLaunchCommandFactory(hadoopEnvSetup, parameters,
             clientContext.getYarnConfig());
     Localizer localizer = new Localizer(fsOperations,
         clientContext.getRemoteDirectoryManager(), parameters);
@@ -113,6 +152,22 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
     return serviceWrapper.getService();
   }
 
+  private Service createPyTorchServiceSpec(PyTorchRunJobParameters parameters,
+      FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
+      throws IOException {
+    PyTorchLaunchCommandFactory launchCommandFactory =
+        new PyTorchLaunchCommandFactory(hadoopEnvSetup, parameters,
+            clientContext.getYarnConfig());
+    Localizer localizer = new Localizer(fsOperations,
+        clientContext.getRemoteDirectoryManager(), parameters);
+    PyTorchServiceSpec pyTorchServiceSpec = new PyTorchServiceSpec(
+        parameters, this.clientContext, fsOperations, launchCommandFactory,
+        localizer);
+
+    serviceWrapper = pyTorchServiceSpec.create();
+    return serviceWrapper.getService();
+  }
+
   @VisibleForTesting
   public ServiceWrapper getServiceWrapper() {
     return serviceWrapper;

+ 4 - 7
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommand.java

@@ -17,11 +17,9 @@
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
 
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import java.io.IOException;
-import java.util.Objects;
 
 /**
  * Abstract base class for Launch command implementations for Services.
@@ -32,10 +30,9 @@ public abstract class AbstractLaunchCommand {
   private final LaunchScriptBuilder builder;
 
   public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters)
-      throws IOException {
-    Objects.requireNonNull(taskType, "TaskType must not be null!");
-    this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
+      Component component, RunJobParameters parameters,
+      String launchCommandPrefix) throws IOException {
+    this.builder = new LaunchScriptBuilder(launchCommandPrefix, hadoopEnvSetup,
         parameters, component);
   }
 

+ 5 - 42
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchCommandFactory.java

@@ -16,52 +16,15 @@
 
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
 
-import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 
 import java.io.IOException;
-import java.util.Objects;
 
 /**
- * Simple factory to create instances of {@link AbstractLaunchCommand}
- * based on the {@link TaskType}.
- * All dependencies are passed to this factory that could be required
- * by any implementor of {@link AbstractLaunchCommand}.
+ * Interface for creating launch commands.
  */
-public class LaunchCommandFactory {
-  private final HadoopEnvironmentSetup hadoopEnvSetup;
-  private final RunJobParameters parameters;
-  private final Configuration yarnConfig;
-
-  public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
-      RunJobParameters parameters, Configuration yarnConfig) {
-    this.hadoopEnvSetup = hadoopEnvSetup;
-    this.parameters = parameters;
-    this.yarnConfig = yarnConfig;
-  }
-
-  public AbstractLaunchCommand createLaunchCommand(TaskType taskType,
-      Component component) throws IOException {
-    Objects.requireNonNull(taskType, "TaskType must not be null!");
-
-    if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) {
-      return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
-          component, parameters, yarnConfig);
-
-    } else if (taskType == TaskType.PS) {
-      return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
-          parameters, yarnConfig);
-
-    } else if (taskType == TaskType.TENSORBOARD) {
-      return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component,
-          parameters);
-    }
-    throw new IllegalStateException("Unknown task type: " + taskType);
-  }
+public interface LaunchCommandFactory {
+  AbstractLaunchCommand createLaunchCommand(Role role, Component component)
+      throws IOException;
 }

+ 4 - 3
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java

@@ -17,7 +17,7 @@
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
 
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -47,10 +47,11 @@ public class LaunchScriptBuilder {
   private final StringBuilder scriptBuffer;
   private String launchCommand;
 
-  LaunchScriptBuilder(String namePrefix,
+  LaunchScriptBuilder(String launchScriptPrefix,
       HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters,
       Component component) throws IOException {
-    this.file = File.createTempFile(namePrefix + "-launch-script", ".sh");
+    this.file = File.createTempFile(launchScriptPrefix +
+        "-launch-script", ".sh");
     this.hadoopEnvSetup = hadoopEnvSetup;
     this.parameters = parameters;
     this.component = component;

+ 61 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/PyTorchLaunchCommandFactory.java

@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command.PyTorchWorkerLaunchCommand;
+
+/**
+ * Simple factory to create instances of {@link AbstractLaunchCommand}
+ * based on the {@link Role}.
+ * All dependencies are passed to this factory that could be required
+ * by any implementor of {@link AbstractLaunchCommand}.
+ */
+public class PyTorchLaunchCommandFactory implements LaunchCommandFactory {
+  private final HadoopEnvironmentSetup hadoopEnvSetup;
+  private final PyTorchRunJobParameters parameters;
+  private final Configuration yarnConfig;
+
+  public PyTorchLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
+      PyTorchRunJobParameters parameters, Configuration yarnConfig) {
+    this.hadoopEnvSetup = hadoopEnvSetup;
+    this.parameters = parameters;
+    this.yarnConfig = yarnConfig;
+  }
+
+  public AbstractLaunchCommand createLaunchCommand(Role role,
+      Component component) throws IOException {
+    Objects.requireNonNull(role, "Role must not be null!");
+
+    if (role == PyTorchRole.WORKER ||
+        role == PyTorchRole.PRIMARY_WORKER) {
+      return new PyTorchWorkerLaunchCommand(hadoopEnvSetup, role,
+          component, parameters, yarnConfig);
+
+    } else {
+      throw new IllegalStateException("Unknown task type: " + role);
+    }
+  }
+}

+ 70 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TensorFlowLaunchCommandFactory.java

@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Simple factory to create instances of {@link AbstractLaunchCommand}
+ * based on the {@link Role}.
+ * All dependencies are passed to this factory that could be required
+ * by any implementor of {@link AbstractLaunchCommand}.
+ */
+public class TensorFlowLaunchCommandFactory implements LaunchCommandFactory {
+  private final HadoopEnvironmentSetup hadoopEnvSetup;
+  private final TensorFlowRunJobParameters parameters;
+  private final Configuration yarnConfig;
+
+  public TensorFlowLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
+      TensorFlowRunJobParameters parameters, Configuration yarnConfig) {
+    this.hadoopEnvSetup = hadoopEnvSetup;
+    this.parameters = parameters;
+    this.yarnConfig = yarnConfig;
+  }
+
+  @Override
+  public AbstractLaunchCommand createLaunchCommand(Role role,
+      Component component) throws IOException {
+    Objects.requireNonNull(role, "Role must not be null!");
+
+    if (role == TensorFlowRole.WORKER ||
+        role == TensorFlowRole.PRIMARY_WORKER) {
+      return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, role,
+          component, parameters, yarnConfig);
+
+    } else if (role == TensorFlowRole.PS) {
+      return new TensorFlowPsLaunchCommand(hadoopEnvSetup, role, component,
+          parameters, yarnConfig);
+
+    } else if (role == TensorFlowRole.TENSORBOARD) {
+      return new TensorBoardLaunchCommand(hadoopEnvSetup, role, component,
+          parameters);
+    }
+    throw new IllegalStateException("Unknown task type: " + role);
+  }
+}

+ 68 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/PyTorchServiceSpec.java

@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;
+
+import java.io.IOException;
+
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.utils.Localizer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class contains all the logic to create an instance
+ * of a {@link Service} object for PyTorch.
+ * Please note that currently, only single-node (non-distributed)
+ * support is implemented for PyTorch.
+ */
+public class PyTorchServiceSpec extends AbstractServiceSpec {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(PyTorchServiceSpec.class);
+  //this field is needed in the future!
+  private final PyTorchRunJobParameters pyTorchParameters;
+
+  public PyTorchServiceSpec(PyTorchRunJobParameters parameters,
+      ClientContext clientContext, FileSystemOperations fsOperations,
+      PyTorchLaunchCommandFactory launchCommandFactory, Localizer localizer) {
+    super(parameters, clientContext, fsOperations, launchCommandFactory,
+        localizer);
+    this.pyTorchParameters = parameters;
+  }
+
+  @Override
+  public ServiceWrapper create() throws IOException {
+    LOG.info("Creating PyTorch service spec");
+    ServiceWrapper serviceWrapper = createServiceSpecWrapper();
+
+    if (parameters.getNumWorkers() > 0) {
+      addWorkerComponents(serviceWrapper, Framework.PYTORCH);
+    }
+
+    // After all components added, handle quicklinks
+    handleQuicklinks(serviceWrapper.getService());
+
+    return serviceWrapper;
+  }
+
+}

+ 87 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/command/PyTorchWorkerLaunchCommand.java

@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;
+
+import java.io.IOException;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Launch command implementation for PyTorch components.
+ */
+public class PyTorchWorkerLaunchCommand extends AbstractLaunchCommand {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(PyTorchWorkerLaunchCommand.class);
+  private final Configuration yarnConfig;
+  private final boolean distributed;
+  private final int numberOfWorkers;
+  private final String name;
+  private final Role role;
+  private final String launchCommand;
+
+  public PyTorchWorkerLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
+      Role role, Component component,
+      PyTorchRunJobParameters parameters,
+      Configuration yarnConfig) throws IOException {
+    super(hadoopEnvSetup, component, parameters, role.getName());
+    this.role = role;
+    this.name = parameters.getName();
+    this.distributed = parameters.isDistributed();
+    this.numberOfWorkers = parameters.getNumWorkers();
+    this.yarnConfig = yarnConfig;
+    logReceivedParameters();
+
+    this.launchCommand = parameters.getWorkerLaunchCmd();
+
+    if (StringUtils.isEmpty(this.launchCommand)) {
+      throw new IllegalArgumentException("LaunchCommand must not be null " +
+          "or empty!");
+    }
+  }
+
+  private void logReceivedParameters() {
+    if (this.numberOfWorkers <= 0) {
+      LOG.warn("Received number of workers: {}", this.numberOfWorkers);
+    }
+  }
+
+  @Override
+  public String generateLaunchScript() throws IOException {
+    LaunchScriptBuilder builder = getBuilder();
+    return builder
+        .withLaunchCommand(createLaunchCommand())
+        .build();
+  }
+
+  @Override
+  public String createLaunchCommand() {
+    if (SubmarineLogs.isVerbose()) {
+      LOG.info("PyTorch Worker command =[" + launchCommand + "]");
+    }
+    return launchCommand + '\n';
+  }
+}

+ 19 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/command/package-info.java

@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate PyTorch launch commands.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;

+ 47 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/component/PyTorchWorkerComponent.java

@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+
+import java.io.IOException;
+
+/**
+ * Component implementation for Worker process of PyTorch.
+ */
+public class PyTorchWorkerComponent extends AbstractComponent {
+  public PyTorchWorkerComponent(FileSystemOperations fsOperations,
+      RemoteDirectoryManager remoteDirectoryManager,
+      PyTorchRunJobParameters parameters, Role role,
+      PyTorchLaunchCommandFactory launchCommandFactory,
+      Configuration yarnConfig) {
+    super(fsOperations, remoteDirectoryManager, parameters, role,
+        yarnConfig, launchCommandFactory);
+  }
+
+  @Override
+  public Component createComponent() throws IOException {
+    return createComponentInternal();
+  }
+}

+ 20 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/component/package-info.java

@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate
+ * PyTorch Native Service components.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;

+ 20 - 0
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/pytorch/package-info.java

@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate
+ * PyTorch-related Native Service runtime artifacts.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;

+ 5 - 5
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowCommons.java

@@ -20,7 +20,7 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
 import org.apache.hadoop.yarn.service.api.records.Component;
 import org.apache.hadoop.yarn.submarine.common.Envs;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
 
 import java.util.Map;
@@ -35,10 +35,10 @@ public final class TensorFlowCommons {
   }
 
   public static void addCommonEnvironments(Component component,
-      TaskType taskType) {
+      Role role) {
     Map<String, String> envs = component.getConfiguration().getEnv();
     envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
-    envs.put(Envs.TASK_TYPE_ENV, taskType.name());
+    envs.put(Envs.TASK_TYPE_ENV, role.getName());
   }
 
   public static String getUserName() {
@@ -49,8 +49,8 @@ public final class TensorFlowCommons {
     return yarnConfig.get("hadoop.registry.dns.domain-name");
   }
 
-  public static String getScriptFileName(TaskType taskType) {
-    return "run-" + taskType.name() + ".sh";
+  public static String getScriptFileName(Role role) {
+    return "run-" + role.getName() + ".sh";
   }
 
   public static String getTFConfigEnv(String componentName, int nWorkers,

+ 22 - 125
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowServiceSpec.java

@@ -16,39 +16,24 @@
 
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
 
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
 import org.apache.hadoop.yarn.service.api.records.Service;
-import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
-import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
-import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
 import org.apache.hadoop.yarn.submarine.utils.Localizer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
 
-import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
-import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL;
-import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
-import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
 
 /**
  * This class contains all the logic to create an instance
@@ -56,42 +41,34 @@ import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handle
  * Worker,PS and Tensorboard components are added to the Service
  * based on the value of the received {@link RunJobParameters}.
  */
-public class TensorFlowServiceSpec implements ServiceSpec {
+public class TensorFlowServiceSpec extends AbstractServiceSpec {
   private static final Logger LOG =
       LoggerFactory.getLogger(TensorFlowServiceSpec.class);
+  private final TensorFlowRunJobParameters tensorFlowParameters;
 
-  private final RemoteDirectoryManager remoteDirectoryManager;
-
-  private final RunJobParameters parameters;
-  private final Configuration yarnConfig;
-  private final FileSystemOperations fsOperations;
-  private final LaunchCommandFactory launchCommandFactory;
-  private final Localizer localizer;
-
-  public TensorFlowServiceSpec(RunJobParameters parameters,
+  public TensorFlowServiceSpec(TensorFlowRunJobParameters parameters,
       ClientContext clientContext, FileSystemOperations fsOperations,
-      LaunchCommandFactory launchCommandFactory, Localizer localizer) {
-    this.parameters = parameters;
-    this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
-    this.yarnConfig = clientContext.getYarnConfig();
-    this.fsOperations = fsOperations;
-    this.launchCommandFactory = launchCommandFactory;
-    this.localizer = localizer;
+      TensorFlowLaunchCommandFactory launchCommandFactory,
+      Localizer localizer) {
+    super(parameters, clientContext, fsOperations, launchCommandFactory,
+        localizer);
+    this.tensorFlowParameters = parameters;
   }
 
   @Override
   public ServiceWrapper create() throws IOException {
+    LOG.info("Creating TensorFlow service spec");
     ServiceWrapper serviceWrapper = createServiceSpecWrapper();
 
-    if (parameters.getNumWorkers() > 0) {
-      addWorkerComponents(serviceWrapper);
+    if (tensorFlowParameters.getNumWorkers() > 0) {
+      addWorkerComponents(serviceWrapper, Framework.TENSORFLOW);
     }
 
-    if (parameters.getNumPS() > 0) {
+    if (tensorFlowParameters.getNumPS() > 0) {
       addPsComponent(serviceWrapper);
     }
 
-    if (parameters.isTensorboardEnabled()) {
+    if (tensorFlowParameters.isTensorboardEnabled()) {
       createTensorBoardComponent(serviceWrapper);
     }
 
@@ -101,103 +78,23 @@ public class TensorFlowServiceSpec implements ServiceSpec {
     return serviceWrapper;
   }
 
-  private ServiceWrapper createServiceSpecWrapper() throws IOException {
-    Service serviceSpec = new Service();
-    serviceSpec.setName(parameters.getName());
-    serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
-    serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
-
-    KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
-        .create(fsOperations, remoteDirectoryManager, parameters);
-    if (kerberosPrincipal != null) {
-      serviceSpec.setKerberosPrincipal(kerberosPrincipal);
-    }
-
-    handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
-    localizer.handleLocalizations(serviceSpec);
-    return new ServiceWrapper(serviceSpec);
-  }
-
   private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
       throws IOException {
     TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
-        remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
+        remoteDirectoryManager, parameters,
+        (TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
     serviceWrapper.addComponent(tbComponent);
 
     addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
         tbComponent.getTensorboardLink());
   }
 
-  private static void addQuicklink(Service serviceSpec, String label,
-      String link) {
-    Map<String, String> quicklinks = serviceSpec.getQuicklinks();
-    if (quicklinks == null) {
-      quicklinks = new HashMap<>();
-      serviceSpec.setQuicklinks(quicklinks);
-    }
-
-    if (SubmarineLogs.isVerbose()) {
-      LOG.info("Added quicklink, " + label + "=" + link);
-    }
-
-    quicklinks.put(label, link);
-  }
-
-  private void handleQuicklinks(Service serviceSpec)
-      throws IOException {
-    List<Quicklink> quicklinks = parameters.getQuicklinks();
-    if (quicklinks != null && !quicklinks.isEmpty()) {
-      for (Quicklink ql : quicklinks) {
-        // Make sure it is a valid instance name
-        String instanceName = ql.getComponentInstanceName();
-        boolean found = false;
-
-        for (Component comp : serviceSpec.getComponents()) {
-          for (int i = 0; i < comp.getNumberOfContainers(); i++) {
-            String possibleInstanceName = comp.getName() + "-" + i;
-            if (possibleInstanceName.equals(instanceName)) {
-              found = true;
-              break;
-            }
-          }
-        }
-
-        if (!found) {
-          throw new IOException(
-              "Couldn't find a component instance = " + instanceName
-                  + " while adding quicklink");
-        }
-
-        String link = ql.getProtocol()
-            + YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
-                getUserName(), getDNSDomain(yarnConfig), ql.getPort());
-        addQuicklink(serviceSpec, ql.getLabel(), link);
-      }
-    }
-  }
-
-  // Handle worker and primary_worker.
-
-  private void addWorkerComponents(ServiceWrapper serviceWrapper)
-      throws IOException {
-    addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);
-
-    if (parameters.getNumWorkers() > 1) {
-      addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
-    }
-  }
-  private void addWorkerComponent(ServiceWrapper serviceWrapper,
-      RunJobParameters parameters, TaskType taskType) throws IOException {
-    serviceWrapper.addComponent(
-        new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
-        parameters, taskType, launchCommandFactory, yarnConfig));
-  }
-
   private void addPsComponent(ServiceWrapper serviceWrapper)
       throws IOException {
     serviceWrapper.addComponent(
         new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
-            launchCommandFactory, parameters, yarnConfig));
+            (TensorFlowLaunchCommandFactory) launchCommandFactory,
+            parameters, yarnConfig));
   }
 
 }

+ 4 - 4
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorBoardLaunchCommand.java

@@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.slf4j.Logger;
@@ -37,9 +37,9 @@ public class TensorBoardLaunchCommand extends AbstractLaunchCommand {
   private final String checkpointPath;
 
   public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters)
+      Role role, Component component, RunJobParameters parameters)
       throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters);
+    super(hadoopEnvSetup, component, parameters, role.getName());
     Objects.requireNonNull(parameters.getCheckpointPath(),
         "CheckpointPath must not be null as it is part "
             + "of the tensorboard command!");

+ 11 - 7
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java

@@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
@@ -28,6 +28,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.Objects;
 
 /**
  * Launch command implementation for
@@ -41,13 +42,16 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
   private final int numberOfWorkers;
   private final int numberOfPS;
   private final String name;
-  private final TaskType taskType;
+  private final Role role;
 
   TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters,
+      Role role, Component component,
+      TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters);
-    this.taskType = taskType;
+    super(hadoopEnvSetup, component, parameters,
+        role != null ? role.getName(): "");
+    Objects.requireNonNull(role, "TensorFlowRole must not be null!");
+    this.role = role;
     this.name = parameters.getName();
     this.distributed = parameters.isDistributed();
     this.numberOfWorkers = parameters.getNumWorkers();
@@ -72,7 +76,7 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
     // When distributed training is required
     if (distributed) {
       String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
-          taskType.getComponentName(), numberOfWorkers,
+          role.getComponentName(), numberOfWorkers,
           numberOfPS, name,
           TensorFlowCommons.getUserName(),
           TensorFlowCommons.getDNSDomain(yarnConfig));

+ 5 - 4
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java

@@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
@@ -37,9 +37,10 @@ public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand {
   private final String launchCommand;
 
   public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters,
+      Role role, Component component,
+      TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
+    super(hadoopEnvSetup, role, component, parameters, yarnConfig);
     this.launchCommand = parameters.getPSLaunchCmd();
 
     if (StringUtils.isEmpty(this.launchCommand)) {

+ 5 - 5
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java

@@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
@@ -37,10 +37,10 @@ public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand {
   private final String launchCommand;
 
   public TensorFlowWorkerLaunchCommand(
-      HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType,
-      Component component, RunJobParameters parameters,
+      HadoopEnvironmentSetup hadoopEnvSetup, Role role,
+      Component component, TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
+    super(hadoopEnvSetup, role, component, parameters, yarnConfig);
     this.launchCommand = parameters.getWorkerLaunchCmd();
 
     if (StringUtils.isEmpty(this.launchCommand)) {

+ 16 - 12
hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java

@@ -19,13 +19,14 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.compone
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
 import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -54,35 +55,38 @@ public class TensorBoardComponent extends AbstractComponent {
   public TensorBoardComponent(FileSystemOperations fsOperations,
       RemoteDirectoryManager remoteDirectoryManager,
       RunJobParameters parameters,
-      LaunchCommandFactory launchCommandFactory,
+      TensorFlowLaunchCommandFactory launchCommandFactory,
       Configuration yarnConfig) {
     super(fsOperations, remoteDirectoryManager, parameters,
-        TaskType.TENSORBOARD, yarnConfig, launchCommandFactory);
+        TensorFlowRole.TENSORBOARD, yarnConfig, launchCommandFactory);
   }
 
   @Override
   public Component createComponent() throws IOException {
-    Objects.requireNonNull(parameters.getTensorboardResource(),
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) this.parameters;
+
+    Objects.requireNonNull(tensorFlowParams.getTensorboardResource(),
         "TensorBoard resource must not be null!");
 
     Component component = new Component();
-    component.setName(taskType.getComponentName());
+    component.setName(role.getComponentName());
     component.setNumberOfContainers(1L);
     component.setRestartPolicy(RestartPolicyEnum.NEVER);
     component.setResource(convertYarnResourceToServiceResource(
-        parameters.getTensorboardResource()));
+        tensorFlowParams.getTensorboardResource()));
 
-    if (parameters.getTensorboardDockerImage() != null) {
+    if (tensorFlowParams.getTensorboardDockerImage() != null) {
       component.setArtifact(
-          getDockerArtifact(parameters.getTensorboardDockerImage()));
+          getDockerArtifact(tensorFlowParams.getTensorboardDockerImage()));
     }
 
-    addCommonEnvironments(component, taskType);
+    addCommonEnvironments(component, role);
     generateLaunchCommand(component);
 
     tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
         parameters.getName(),
-        taskType.getComponentName() + "-" + 0, getUserName(),
+        role.getComponentName() + "-" + 0, getUserName(),
         getDNSDomain(yarnConfig), DEFAULT_PORT);
     LOG.info("Link to tensorboard:" + tensorboardLink);
 

Some files were not shown because too many files changed in this diff