001/* 002 * Copyright (C) 2012 eXo Platform SAS. 003 * 004 * This is free software; you can redistribute it and/or modify it 005 * under the terms of the GNU Lesser General Public License as 006 * published by the Free Software Foundation; either version 2.1 of 007 * the License, or (at your option) any later version. 008 * 009 * This software is distributed in the hope that it will be useful, 010 * but WITHOUT ANY WARRANTY; without even the implied warranty of 011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 012 * Lesser General Public License for more details. 013 * 014 * You should have received a copy of the GNU Lesser General Public 015 * License along with this software; if not, write to the Free 016 * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 017 * 02110-1301 USA, or see the FSF site: http://www.fsf.org. 018 */ 019package org.crsh.ssh.term; 020 021import org.apache.sshd.SshServer; 022import org.apache.sshd.common.KeyPairProvider; 023import org.apache.sshd.common.NamedFactory; 024import org.apache.sshd.common.Session; 025import org.apache.sshd.server.Command; 026import org.apache.sshd.server.PasswordAuthenticator; 027import org.apache.sshd.server.PublickeyAuthenticator; 028import org.apache.sshd.server.session.ServerSession; 029import org.crsh.plugin.PluginContext; 030import org.crsh.auth.AuthenticationPlugin; 031import org.crsh.ssh.term.scp.SCPCommandFactory; 032import org.crsh.ssh.term.subsystem.SubsystemFactoryPlugin; 033import org.crsh.term.TermLifeCycle; 034import org.crsh.term.spi.TermIOHandler; 035 036import java.security.PublicKey; 037import java.util.ArrayList; 038import java.util.List; 039import java.util.logging.Level; 040import java.util.logging.Logger; 041 042/** 043 * Interesting stuff here : http://gerrit.googlecode.com/git-history/4b9e5e7fb9380cfadd28d7ffe3dc496dc06f5892/gerrit-sshd/src/main/java/com/google/gerrit/sshd/DatabasePubKeyAuth.java 044 */ 045public class SSHLifeCycle extends TermLifeCycle { 046 047 /** . */ 048 public static final Session.AttributeKey<String> USERNAME = new Session.AttributeKey<java.lang.String>(); 049 050 /** . */ 051 public static final Session.AttributeKey<String> PASSWORD = new Session.AttributeKey<java.lang.String>(); 052 053 /** . */ 054 private final Logger log = Logger.getLogger(SSHLifeCycle.class.getName()); 055 056 /** . */ 057 private SshServer server; 058 059 /** . */ 060 private int port; 061 062 /** . */ 063 private KeyPairProvider keyPairProvider; 064 065 /** . */ 066 private final AuthenticationPlugin authentication; 067 068 /** . */ 069 private Integer localPort; 070 071 public SSHLifeCycle(PluginContext context, AuthenticationPlugin<?> authentication) { 072 super(context); 073 074 // 075 this.authentication = authentication; 076 } 077 078 public int getPort() { 079 return port; 080 } 081 082 public void setPort(int port) { 083 this.port = port; 084 } 085 086 /** 087 * Returns the local part after the ssh server has been succesfully bound or null. This is useful when 088 * the port is chosen at random by the system. 089 * 090 * @return the local port 091 */ 092 public Integer getLocalPort() { 093 return localPort; 094 } 095 096 public KeyPairProvider getKeyPairProvider() { 097 return keyPairProvider; 098 } 099 100 public void setKeyPairProvider(KeyPairProvider keyPairProvider) { 101 this.keyPairProvider = keyPairProvider; 102 } 103 104 @Override 105 protected void doInit() { 106 try { 107 108 // 109 TermIOHandler handler = getHandler(); 110 111 // 112 SshServer server = SshServer.setUpDefaultServer(); 113 server.setPort(port); 114 server.setShellFactory(new CRaSHCommandFactory(handler)); 115 server.setCommandFactory(new SCPCommandFactory(getContext())); 116 server.setKeyPairProvider(keyPairProvider); 117 118 // 119 ArrayList<NamedFactory<Command>> namedFactoryList = new ArrayList<NamedFactory<Command>>(0); 120 for (SubsystemFactoryPlugin plugin : getContext().getPlugins(SubsystemFactoryPlugin.class)) { 121 namedFactoryList.add(plugin.getFactory()); 122 } 123 server.setSubsystemFactories(namedFactoryList); 124 125 // 126 if (authentication.getCredentialType().equals(String.class)) { 127 @SuppressWarnings("unchecked") 128 final AuthenticationPlugin<String> passwordAuthentication = (AuthenticationPlugin<String>)authentication; 129 server.setPasswordAuthenticator(new PasswordAuthenticator() { 130 public boolean authenticate(String _username, String _password, ServerSession session) { 131 boolean auth; 132 try { 133 log.log(Level.FINE, "Using authentication plugin " + authentication + " to authenticate user " + _username); 134 auth = passwordAuthentication.authenticate(_username, _password); 135 } catch (Exception e) { 136 log.log(Level.SEVERE, "Exception authenticating user " + _username + " in authentication plugin: " + authentication, e); 137 return false; 138 } 139 140 // We store username and password in session for later reuse 141 session.setAttribute(USERNAME, _username); 142 session.setAttribute(PASSWORD, _password); 143 144 // 145 return auth; 146 } 147 }); 148 } else if (authentication.getCredentialType().equals(PublicKey.class)) { 149 @SuppressWarnings("unchecked") 150 final AuthenticationPlugin<PublicKey> keyAuthentication = (AuthenticationPlugin<PublicKey>)authentication; 151 server.setPublickeyAuthenticator(new PublickeyAuthenticator() { 152 public boolean authenticate(String username, PublicKey key, ServerSession session) { 153 try { 154 log.log(Level.FINE, "Using authentication plugin " + authentication + " to authenticate user " + username); 155 156 157 return keyAuthentication.authenticate(username, key); 158 } 159 catch (Exception e) { 160 log.log(Level.SEVERE, "Exception authenticating user " + username + " in authentication plugin: " + authentication, e); 161 return false; 162 } 163 } 164 }); 165 } 166 167 // 168 log.log(Level.INFO, "About to start CRaSSHD"); 169 server.start(); 170 localPort = server.getPort(); 171 log.log(Level.INFO, "CRaSSHD started on port " + localPort); 172 173 // 174 this.server = server; 175 } 176 catch (Throwable e) { 177 log.log(Level.SEVERE, "Could not start CRaSSHD", e); 178 } 179 } 180 181 @Override 182 protected void doDestroy() { 183 if (server != null) { 184 try { 185 server.stop(); 186 } 187 catch (InterruptedException e) { 188 log.log(Level.FINE, "Got an interruption when stopping server", e); 189 } 190 } 191 } 192}