001    /*
002     * Copyright (C) 2012 eXo Platform SAS.
003     *
004     * This is free software; you can redistribute it and/or modify it
005     * under the terms of the GNU Lesser General Public License as
006     * published by the Free Software Foundation; either version 2.1 of
007     * the License, or (at your option) any later version.
008     *
009     * This software is distributed in the hope that it will be useful,
010     * but WITHOUT ANY WARRANTY; without even the implied warranty of
011     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
012     * Lesser General Public License for more details.
013     *
014     * You should have received a copy of the GNU Lesser General Public
015     * License along with this software; if not, write to the Free
016     * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
017     * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
018     */
019    package org.crsh.ssh.term;
020    
021    import org.apache.sshd.SshServer;
022    import org.apache.sshd.common.Session;
023    import org.apache.sshd.server.PasswordAuthenticator;
024    import org.apache.sshd.server.PublickeyAuthenticator;
025    import org.apache.sshd.server.session.ServerSession;
026    import org.crsh.plugin.PluginContext;
027    import org.crsh.auth.AuthenticationPlugin;
028    import org.crsh.ssh.term.scp.SCPCommandFactory;
029    import org.crsh.term.TermLifeCycle;
030    import org.crsh.term.spi.TermIOHandler;
031    import org.crsh.vfs.Resource;
032    
033    import java.security.PublicKey;
034    import java.util.Set;
035    import java.util.logging.Level;
036    import java.util.logging.Logger;
037    
038    /**
039     * Interesting stuff here : http://gerrit.googlecode.com/git-history/4b9e5e7fb9380cfadd28d7ffe3dc496dc06f5892/gerrit-sshd/src/main/java/com/google/gerrit/sshd/DatabasePubKeyAuth.java
040     */
041    public class SSHLifeCycle extends TermLifeCycle {
042    
043      /** . */
044      public static final Session.AttributeKey<String> USERNAME = new Session.AttributeKey<java.lang.String>();
045    
046      /** . */
047      public static final Session.AttributeKey<String> PASSWORD = new Session.AttributeKey<java.lang.String>();
048    
049      /** . */
050      private final Logger log = Logger.getLogger(SSHLifeCycle.class.getName());
051    
052      /** . */
053      private SshServer server;
054    
055      /** . */
056      private int port;
057    
058      /** . */
059      private Resource key;
060    
061      /** . */
062      private final AuthenticationPlugin authentication;
063    
064      /** . */
065      private Integer localPort;
066    
067      /** . */
068      private final Set<PublicKey> authorizedKeys;
069    
070      public SSHLifeCycle(PluginContext context, AuthenticationPlugin authentication, Set<PublicKey> authorizedKeys) {
071        super(context);
072    
073        //
074        this.authentication = authentication;
075        this.authorizedKeys = authorizedKeys;
076      }
077    
078      public int getPort() {
079        return port;
080      }
081    
082      public void setPort(int port) {
083        this.port = port;
084      }
085    
086      /**
087       * Returns the local part after the ssh server has been succesfully bound or null. This is useful when
088       * the port is chosen at random by the system.
089       *
090       * @return the local port
091       */
092      public Integer getLocalPort() {
093              return localPort;
094      }
095      
096      public Resource getKey() {
097        return key;
098      }
099    
100      public void setKey(Resource key) {
101        this.key = key;
102      }
103    
104      @Override
105      protected void doInit() {
106        try {
107    
108          //
109          TermIOHandler handler = getHandler();
110    
111          //
112          SshServer server = SshServer.setUpDefaultServer();
113          server.setPort(port);
114          server.setShellFactory(new CRaSHCommandFactory(handler));
115          server.setCommandFactory(new SCPCommandFactory(getContext()));
116          server.setKeyPairProvider(new URLKeyPairProvider(key));
117    
118          //
119          if (authorizedKeys != null && authorizedKeys.size() > 0) {
120            server.setPublickeyAuthenticator(new PublickeyAuthenticator() {
121              public boolean authenticate(String username, PublicKey key, ServerSession session) {
122                if (authorizedKeys.contains(key)) {
123                  log.log(Level.FINE, "Authenticated " + username + " with public key " + key);
124                  return true;
125                } else {
126                  log.log(Level.FINE, "Denied " + username + " with public key " + key);
127                  return false;
128                }
129              }
130            });
131          }
132    
133          //
134          server.setPasswordAuthenticator(new PasswordAuthenticator() {
135            public boolean authenticate(String _username, String _password, ServerSession session) {
136              boolean auth;
137              try {
138                log.log(Level.FINE, "Using authentication plugin " + authentication + " to authenticate user " + _username);
139                auth = authentication.authenticate(_username, _password);
140              } catch (Exception e) {
141                log.log(Level.SEVERE, "Exception authenticating user " + _username + " in authentication plugin: " + authentication, e);
142                return false;
143              }
144    
145              // We store username and password in session for later reuse
146              session.setAttribute(USERNAME, _username);
147              session.setAttribute(PASSWORD, _password);
148    
149              //
150              return auth;
151            }
152          });
153    
154          //
155          log.log(Level.INFO, "About to start CRaSSHD");
156          server.start();
157          localPort = server.getPort();
158          log.log(Level.INFO, "CRaSSHD started on port " + localPort);
159    
160          //
161          this.server = server;
162        }
163        catch (Throwable e) {
164          log.log(Level.SEVERE, "Could not start CRaSSHD", e);
165        }
166      }
167    
168      @Override
169      protected void doDestroy() {
170        if (server != null) {
171          try {
172            server.stop();
173          }
174          catch (InterruptedException e) {
175            log.log(Level.FINE, "Got an interruption when stopping server", e);
176          }
177        }
178      }
179    }